/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/runtime/jit_executable.h"

#include <memory>
#include <string_view>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/constraints.h"

namespace xla {
namespace runtime {

using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

using llvm::ArrayRef;
using llvm::ErrorOr;
using llvm::Expected;
using llvm::Optional;

using tfrt::MakeAvailableAsyncValueRef;
using tfrt::MakeErrorAsyncValueRef;
using tfrt::MakeStringError;

using Specialization = JitExecutable::Specialization;

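// Returns true if any of the argument constraints can only be satisfied by
// specializing the module to the concrete arguments (i.e. the constraint is
// anything other than `kResolved`).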
static bool IsSpecializationOnly(ArrayRef<ArgumentConstraint> constraints) {
  return llvm::any_of(constraints, [](ArgumentConstraint constraint) {
    return constraint != ArgumentConstraint::kResolved;
  });
}

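// Returns true if any of the arguments is constrained by its value.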
static bool HasValueConstraints(ArrayRef<ArgumentConstraint> constraints) {
  return llvm::any_of(constraints, [](ArgumentConstraint constraint) {
    return constraint == ArgumentConstraint::kValue;
  });
}

// Returns true if all function operands have statically known shape.
static bool HasStaticShapeOperands(const FunctionType& signature) {
  auto is_dynamic = [](ArrayRef<int64_t> sizes) -> bool {
    return llvm::any_of(sizes, mlir::ShapedType::isDynamic);
  };

  for (unsigned i = 0; i < signature.num_operands(); ++i) {
    const Type* type = signature.operand(i);

    // Get the underlying value type from the async value.
    while (auto* value = dyn_cast<AsyncValueType>(type))
      type = &value->value_type();

    // Unranked types do not have statically known shape.
    if (isa<UnrankedTensorType, UnrankedMemrefType>(type)) return false;

    // For ranked memrefs and tensors check known sizes.
    if (auto* memref = dyn_cast<MemrefType>(type))
      if (is_dynamic(memref->sizes())) return false;
    if (auto* tensor = dyn_cast<RankedTensorType>(type))
      if (is_dynamic(tensor->sizes())) return false;

    // All other types are non-shaped and thus have "statically known shape".

    // TODO(ezhulenev): Run time types might need to support type interfaces or
    // a hierarchy with a base `ShapedType` so that users can define their own
    // types that can participate in shape specialization. This becomes
    // complicated for container-like types (e.g. async value) that might
    // contain a nested type that is shaped (e.g. memref). For now only the
    // canonical types can participate in shape specialization.
  }

  return true;
}

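// Runs the compilation task immediately in the caller thread, ignoring the
// specialization details and the user data.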
/*static*/ void JitExecutable::InlineCompilationTaskRunner(
    size_t num_specializations, ArrayRef<ArgumentConstraint> constraints,
    ArgumentsRef arguments, CompilationTask task, UserData user_data) {
  task();
}

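// Instantiates a JitExecutable from the MLIR module: resolves the entrypoint
// argument constraints, converts the entrypoint signature, and either compiles
// the default executable or defers all compilation to specializations.
//
// A minimal usage sketch (hypothetical `mlir_module` string and `opts`, not
// taken from this file; the inline runner is passed explicitly):
//
//   Expected<JitExecutable> executable = JitExecutable::Instantiate(
//       mlir_module, /*entrypoint=*/"main", opts,
//       /*memory_region_name=*/"my_module",
//       JitExecutable::InlineCompilationTaskRunner);
//   if (auto err = executable.takeError()) { /* report the error */ }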
/*static*/ Expected<JitExecutable> JitExecutable::Instantiate(
    std::string_view mlir_module, std::string_view entrypoint, Options opts,
    std::string_view memory_region_name, CompilationTaskRunner runner) {
  // Try to instantiate the compilation context from the MLIR source.
  Expected<std::unique_ptr<JitCompiler>> compiler =
      JitCompiler::Instantiate(opts.compiler, mlir_module, entrypoint);
  if (auto err = compiler.takeError()) return std::move(err);

  // Get the resolved operand constraints for the entrypoint function.
  auto constraints = GetArgumentsConstraints((*compiler)->entrypoint());
  if (auto err = constraints.takeError()) return std::move(err);

  // Get the entrypoint function signature; it will later be required to
  // compute the specialized function signature from the operands at run time.
  auto signature = opts.compiler.type_converter.Convert(
      (*compiler)->entrypoint().getFunctionType());
  if (auto err = signature.takeError()) return std::move(err);

  // If all of the operands have static shape, then we can always use the
  // default binary for execution (unless specialization is explicitly required
  // by the operand constraints).
  if (HasStaticShapeOperands(*signature) && !IsSpecializationOnly(*constraints))
    opts.specialization = Specialization::kDisabled;

  // Return an error if specialization is explicitly disabled, yet some of
  // the operands have unresolved constraints.
  if (opts.specialization == Specialization::kDisabled &&
      IsSpecializationOnly(*constraints))
    return MakeStringError(
        "compilation options disabled specialization, yet operands have "
        "unresolved constraints: ",
        *constraints);

  // If the module must be specialized, return a JitExecutable without a
  // default compiled executable.
  if (opts.specialization == Specialization::kAlways ||
      IsSpecializationOnly(*constraints))
    return JitExecutable(mlir_module, entrypoint, memory_region_name,
                         std::move(opts), std::move(*constraints),
                         std::move(*signature),
                         /*default_executable=*/llvm::None, std::move(runner));

  // Otherwise try to compile the default executable.
  Expected<Executable> executable =
      JitCompiler::Compile(std::move(*compiler), memory_region_name);
  if (auto err = executable.takeError()) return std::move(err);

  return JitExecutable(mlir_module, entrypoint, memory_region_name,
                       std::move(opts), std::move(*constraints),
                       std::move(*signature), std::move(*executable),
                       std::move(runner));
}

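// The constructor stores the MLIR module source and the compilation options so
// that specialized executables can be compiled from the source later, and
// wraps the default executable (if any) into an async value.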
JitExecutable::JitExecutable(std::string_view mlir_module,
                             std::string_view entrypoint,
                             std::string_view memory_region_name, Options opts,
                             ArrayRef<ArgumentConstraint> constraints,
                             FunctionType signature,
                             Optional<Executable> default_executable,
                             CompilationTaskRunner runner)
    : mlir_module_(mlir_module),
      entrypoint_(entrypoint),
      memory_region_name_(memory_region_name),
      opts_(std::move(opts)),
      constraints_(constraints.begin(), constraints.end()),
      has_value_constraints_(HasValueConstraints(constraints_)),
      signature_(std::move(signature)),
      symbolic_shapes_resolver_(signature_, constraints_),
      has_default_executable_(default_executable.has_value()),
      runner_(std::move(runner)),
      specializations_(std::make_unique<Specializations>()) {
  // Initialize the default executable if it is available.
  if (has_default_executable_) {
    default_executable_ =
        MakeAvailableAsyncValueRef<Executable>(std::move(*default_executable));
  } else {
    default_executable_ =
        MakeErrorAsyncValueRef("default executable is not available");
  }
}

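// Returns the default executable as an async value pointer (an error async
// value if the default executable was never compiled).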
AsyncValuePtr<Executable> JitExecutable::DefaultExecutable() const {
  return default_executable_.AsPtr();
}

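// Returns the resolved constraints of the entrypoint arguments.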
ArrayRef<ArgumentConstraint> JitExecutable::constraints() const {
  return constraints_;
}

// Combines `hash` with a hash value computed from the value-constrained
// operands.
static llvm::hash_code CombineWithValueConstraineOperands(
    llvm::hash_code hash, ArgumentsRef arguments,
    ArrayRef<ArgumentConstraint> constraints) {
  for (int i = 0; i < constraints.size(); ++i) {
    if (LLVM_LIKELY(constraints[i] != ArgumentConstraint::kValue)) continue;

    // TODO(ezhulenev): Currently we only support value specialization of
    // Tensor operands (with MemrefDesc run time argument); it should be
    // extended to support open type and argument hierarchies.
    const MemrefDesc& memref = cast<MemrefDesc>(arguments[i]);
    const auto* data = static_cast<uint8_t*>(memref.data());
    size_t rank = memref.rank();
    assert(rank == 0 || rank == 1);
    size_t num_values = rank == 0 ? 1 : memref.size(0);
    int64_t len = num_values * GetHostSize(memref.dtype());
    hash = llvm::hash_combine(hash, llvm::hash_combine_range(data, data + len));
  }
  return hash;
}

// TODO(ezhulenev): The fast path should be free of mutex to find the
// pre-compiled specialization. Maybe use atomic pointers (multiple atomic
// pointers?) to keep the most commonly used specialization available without
// doing a lookup in the AsyncValuesCache.
//
// TODO(ezhulenev): The number of specializations should be bounded; ideally we
// should only keep the N most common specializations, and for everything else
// fall back on the default executable. However, what to do if the default
// executable is not available, and the number of specializations is above N?
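// Returns an executable for the given arguments: the default executable if
// specialization is disabled or not yet needed, a cached specialization if one
// exists for the resolved symbolic shapes (and constrained values), or a
// freshly scheduled specialized compilation otherwise.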
Expected<AsyncValuePtr<Executable>> JitExecutable::GetExecutable(
    ArgumentsRef arguments, UserData user_data,
    const SpecializationListener* listener) {
  // Do not try to compile a specialized executable if it is explicitly
  // disabled.
  if (opts_.specialization == Specialization::kDisabled)
    return DefaultExecutable();

  // The number of arguments must match the entrypoint signature.
  if (LLVM_UNLIKELY(arguments.size() != signature_.num_operands()))
    return MakeStringError("expected ", signature_.num_operands(),
                           " arguments, got: ", arguments.size());

  // Resolve the symbolic shapes hash based on the static and runtime
  // information.
  //
  // We rely on the hash code to find the specialized executable. In case of
  // a collision (practically impossible), incompatible arguments will be
  // rejected by the executable arguments verification.
  ErrorOr<llvm::hash_code> hash =
      symbolic_shapes_resolver_.ResolveHash(arguments);

  // If we failed to resolve the symbolic shapes hash, then we need to verify
  // all the operands to find the mismatch and report it to the user.
  if (LLVM_UNLIKELY(hash.getError())) {
    for (unsigned i = 0; i < arguments.size(); ++i) {
      auto* type = signature_.operand(i);

      // TODO(ezhulenev): Support open shaped type/argument hierarchy.
      auto* memref_arg = dyn_cast<MemrefDesc>(&arguments[i]);
      if (!memref_arg) continue;

      if (auto* memref = dyn_cast<MemrefType>(type)) {
        if (auto err = VerifyMemrefArgument(i, *memref, *memref_arg))
          return std::move(err);

      } else if (auto* tensor = dyn_cast<RankedTensorType>(type)) {
        if (auto err = VerifyMemrefArgument(i, *tensor, *memref_arg))
          return std::move(err);

      } else {
        return MakeStringError("expected shaped operand at #", i,
                               ", got: ", *signature_.operand(i));
      }
    }

    assert(false && "failed to detect incorrect operand");
    return MakeStringError("failed to resolve symbolic shapes");
  }

  // Combine with a hash value computed from the value-constrained operands.
  if (LLVM_UNLIKELY(has_value_constraints_))
    *hash = CombineWithValueConstraineOperands(*hash, arguments, constraints_);

  // Maybe return an Executable from the cache.
  if (auto cached = specializations_->Find(*hash)) {
    // Always use the specialized kernel if required by the compilation
    // options.
    if (opts_.specialization == Specialization::kAlways) return cached;

    // Fall back on the default executable if the specialization is not yet
    // available.
    if (has_default_executable_ && !cached.IsAvailable())
      return DefaultExecutable();

    return cached;
  }

  // Instantiation from the source and specialization are cheap, so we do them
  // in the caller thread. We only use the compilation runner for the expensive
  // part.

  // Try to instantiate the compilation context from the MLIR source.
  Expected<std::unique_ptr<JitCompiler>> compiler =
      JitCompiler::Instantiate(opts_.compiler, mlir_module_, entrypoint_);

  if (auto err = compiler.takeError()) {
    assert(false && "parsing mlir module must always succeed at this point");
    return std::move(err);
  }

  // Specialize the executable to the concrete operands.
  ErrorOr<llvm::SmallVector<SymbolicShapesResolver::SymbolicShape>>
      symbolic_shapes = symbolic_shapes_resolver_.Resolve(arguments);
  if (auto err = (*compiler)->Specialize(arguments, *symbolic_shapes,
                                         constraints_, listener)) {
    return MakeStringError("failed to specialize executable: ", err);
  }

  // Allocate a placeholder for the compiled specialization only after we are
  // ready to dispatch the compilation task.
  Specializations::Entry entry = specializations_->Allocate(*hash);

  // We lost the race; some other invocation will do the compilation.
  if (!entry.allocated) return entry.ptr;

  // Get the specialization id from the size of the specializations cache.
  size_t specialization = entry.size - 1;

  // Construct the task that will do the specialized executable compilation.
  auto compile = CompilationTask(
      [compiler = std::move(*compiler), ref = entry.ptr.CopyRef(),
       memory_region_name = memory_region_name_, specialization]() mutable {
        Expected<Executable> executable = JitCompiler::Compile(
            std::move(compiler), memory_region_name, specialization);

        // Set the allocated entry async value state to error or concrete.
        if (auto err = executable.takeError()) {
          ref.SetError(std::move(err));
        } else {
          ref.emplace(std::move(*executable));
        }
      });

  // Offload the specialization compilation to the user-provided runner.
  runner_(specialization, constraints_, arguments, std::move(compile),
          user_data);

  // Use the default executable while we are compiling a specialized version,
  // unless this is explicitly disabled by the compilation options.
  if (opts_.specialization == Specialization::kAlways)
    return entry.ptr;
  else
    return has_default_executable_ ? DefaultExecutable() : entry.ptr;
}

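// Returns a chain that becomes ready when all executables in the
// specializations cache are available.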
AsyncValueRef<Chain> JitExecutable::AllExecutablesCompiled() const {
  return specializations_->AllAvailable();
}

}  // namespace runtime
}  // namespace xla