1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
17
18 #include <stdint.h>
19 #include <algorithm>
20 #include <list>
21 #include <utility>
22
23 #include "absl/memory/memory.h"
24 #include "llvm/ExecutionEngine/ExecutionEngine.h"
25 #include "llvm/ExecutionEngine/JITSymbol.h"
26 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
27 #include "llvm/IR/Mangler.h"
28 #include "llvm/Support/CodeGen.h"
29 #include "llvm/Support/Host.h"
30 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
31 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
32 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
33 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
34 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
35 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
36 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
38 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
39 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
40 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
41 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
42 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
43 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
44 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/core/platform/logging.h"
47
48 namespace xla {
49 namespace cpu {
50 namespace {
51
DetectMachineAttributes()52 llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
53 llvm::SmallVector<std::string, 0> result;
54 llvm::StringMap<bool> host_features;
55 if (llvm::sys::getHostCPUFeatures(host_features)) {
56 for (auto& feature : host_features) {
57 if (feature.second) {
58 llvm::StringRef feature_name = feature.first();
59 // Skip avx512 for now, it isn't quite ready in LLVM.
60 if (feature_name.startswith("avx512")) {
61 continue;
62 }
63 result.push_back(feature_name);
64 }
65 }
66 }
67 return result;
68 }
69
GetHostCpuName()70 llvm::StringRef GetHostCpuName() {
71 auto cpu_name = llvm::sys::getHostCPUName();
72 // Skip avx512 for now, it isn't quite ready in LLVM.
73 cpu_name.consume_back("-avx512");
74 return cpu_name;
75 }
76 } // namespace
77
78 /*static*/ std::unique_ptr<llvm::TargetMachine>
InferTargetMachineForJIT(const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level)79 SimpleOrcJIT::InferTargetMachineForJIT(
80 const llvm::TargetOptions& target_options,
81 llvm::CodeGenOpt::Level opt_level) {
82 std::unique_ptr<llvm::TargetMachine> target_machine(
83 llvm::EngineBuilder()
84 .setTargetOptions(target_options)
85 .setOptLevel(opt_level)
86 .selectTarget(
87 /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
88 /*MCPU=*/GetHostCpuName(),
89 /*MAttrs=*/DetectMachineAttributes()));
90 CHECK(target_machine != nullptr);
91 return target_machine;
92 }
93
SimpleOrcJIT(const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level,bool optimize_for_size,bool enable_fast_math,bool disable_expensive_passes,LLVMCompiler::ModuleHook pre_optimization_hook,LLVMCompiler::ModuleHook post_optimization_hook,std::function<void (const llvm::object::ObjectFile &)> post_codegen_hook)94 SimpleOrcJIT::SimpleOrcJIT(
95 const llvm::TargetOptions& target_options,
96 llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
97 bool enable_fast_math, bool disable_expensive_passes,
98 LLVMCompiler::ModuleHook pre_optimization_hook,
99 LLVMCompiler::ModuleHook post_optimization_hook,
100 std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
101 : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
102 data_layout_(target_machine_->createDataLayout()),
103 symbol_resolver_(llvm::orc::createLegacyLookupResolver(
104 execution_session_,
105 [this](const std::string& name) -> llvm::JITSymbol {
106 return this->ResolveRuntimeSymbol(name);
107 },
__anon103e137a0302(llvm::Error Err) 108 [](llvm::Error Err) {
109 cantFail(std::move(Err), "lookupFlags failed");
110 })),
111 object_layer_(
112 execution_session_,
__anon103e137a0402(llvm::orc::VModuleKey) 113 [this](llvm::orc::VModuleKey) {
114 llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result;
115 result.MemMgr = std::make_shared<llvm::SectionMemoryManager>(
116 orc_jit_memory_mapper::GetInstance());
117 result.Resolver = symbol_resolver_;
118 return result;
119 },
120 /*NotifyLoaded=*/
121 llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(),
122 /*NotifyFinalized=*/
123 [this](VModuleKeyT, const llvm::object::ObjectFile& object,
__anon103e137a0502(VModuleKeyT, const llvm::object::ObjectFile& object, const llvm::RuntimeDyld::LoadedObjectInfo& object_info) 124 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
125 this->NotifyObjectFinalized(object, object_info);
126 },
127 /*NotifyFreed=*/
__anon103e137a0602(VModuleKeyT, const llvm::object::ObjectFile& object) 128 [this](VModuleKeyT, const llvm::object::ObjectFile& object) {
129 this->NotifyObjectFreed(object);
130 }),
131 compile_layer_(
132 object_layer_,
133 CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size,
134 enable_fast_math, disable_expensive_passes,
135 std::move(pre_optimization_hook),
136 std::move(post_optimization_hook),
137 std::move(post_codegen_hook))),
138 gdb_jit_event_listener_(
139 llvm::JITEventListener::createGDBRegistrationListener()) {
140 VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
141 << " features: " << target_machine_->getTargetFeatureString().str();
142 }
143
ResolveRuntimeSymbol(const std::string & name)144 llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
145 void* func_addr = nullptr;
146 if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
147 // On Mac OS X, 'name' may have a leading underscore prefix, even though the
148 // registered name may not.
149 std::string stripped_name(name.begin() + 1, name.end());
150 func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name);
151 } else {
152 func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
153 }
154
155 if (func_addr == nullptr) {
156 LOG(ERROR) << "Unable to resolve runtime symbol: " << name;
157 return nullptr;
158 }
159 llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
160 llvm::JITSymbolFlags::None);
161 return symbol_info;
162 }
163
NotifyObjectFinalized(const llvm::object::ObjectFile & object,const llvm::RuntimeDyld::LoadedObjectInfo & object_info)164 void SimpleOrcJIT::NotifyObjectFinalized(
165 const llvm::object::ObjectFile& object,
166 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
167 uint64_t key = static_cast<uint64_t>(
168 reinterpret_cast<uintptr_t>(object.getData().data()));
169 gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
170 }
171
NotifyObjectFreed(const llvm::object::ObjectFile & object)172 void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) {
173 uint64_t key = static_cast<uint64_t>(
174 reinterpret_cast<uintptr_t>(object.getData().data()));
175 gdb_jit_event_listener_->notifyFreeingObject(key);
176 }
177
AddModule(std::unique_ptr<llvm::Module> module)178 SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
179 std::unique_ptr<llvm::Module> module) {
180 auto key = execution_session_.allocateVModule();
181 cantFail(compile_layer_.addModule(key, std::move(module)));
182 module_keys_.push_back(key);
183 return key;
184 }
185
RemoveModule(SimpleOrcJIT::VModuleKeyT key)186 void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
187 module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key),
188 module_keys_.end());
189 cantFail(compile_layer_.removeModule(key));
190 }
191
FindCompiledSymbol(const std::string & name)192 llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
193 // Resolve symbol from last module to first, allowing later redefinitions of
194 // symbols shadow earlier ones.
195 for (auto& key :
196 llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) {
197 if (auto symbol =
198 compile_layer_.findSymbolIn(key, name,
199 /*ExportedSymbolsOnly=*/true)) {
200 return symbol;
201 }
202 }
203
204 return nullptr;
205 }
206
207 namespace {
208 // Register some known symbols with the CustomCallTargetRegistry.
RegisterKnownJITSymbols()209 bool RegisterKnownJITSymbols() {
210 CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
211
212 #define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
213 do { \
214 auto* function_address = \
215 reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
216 registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
217 function_address); \
218 CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
219 "__xla_cpu_runtime_" #base_name); \
220 } while (false)
221
222 REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
223 REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
224 REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
225 REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
226 REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
227 REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
228 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
229 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
230 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
231 REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
232 REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
233 REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
234 REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
235 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
236 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
237 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft);
238 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
239 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
240 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
241 REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
242 REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
243 REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
244 REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
245
246 registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
247 registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
248
249 #undef REGISTER_CPU_RUNTIME_SYMBOL
250
251 // Register both the f32 (float) and f64 (double) versions of a libm symbol.
252 // Unfortunately the double versions are overloaded on some systems, e.g.
253 // Mac so we need an explicit cast. This requires passing the function signature
254 // for that case.
255 #define REGISTER_LIBM_SYMBOL(name, double_sig) \
256 do { \
257 registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
258 registry->Register( \
259 #name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
260 } while (false)
261
262 REGISTER_LIBM_SYMBOL(acos, double (*)(double));
263 REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
264 REGISTER_LIBM_SYMBOL(asin, double (*)(double));
265 REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
266 REGISTER_LIBM_SYMBOL(atan, double (*)(double));
267 REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
268 REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
269 REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
270 REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
271 REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
272 REGISTER_LIBM_SYMBOL(cos, double (*)(double));
273 REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
274 REGISTER_LIBM_SYMBOL(erf, double (*)(double));
275 REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
276 REGISTER_LIBM_SYMBOL(exp, double (*)(double));
277 REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
278 REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
279 REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
280 REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
281 REGISTER_LIBM_SYMBOL(floor, double (*)(double));
282 REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
283 REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
284 REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
285 REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
286 REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
287 REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
288 REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
289 REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
290 REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
291 REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int)
292 REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int)
293 REGISTER_LIBM_SYMBOL(log, double (*)(double));
294 REGISTER_LIBM_SYMBOL(log10, double (*)(double));
295 REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
296 REGISTER_LIBM_SYMBOL(log2, double (*)(double));
297 REGISTER_LIBM_SYMBOL(logb, double (*)(double));
298 REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int)
299 REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int)
300 REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
301 REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
302 REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
303 REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
304 REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
305 REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
306 REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
307 REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
308 REGISTER_LIBM_SYMBOL(rint, double (*)(double));
309 REGISTER_LIBM_SYMBOL(round, double (*)(double));
310 REGISTER_LIBM_SYMBOL(scalbln,
311 double (*)(double, long)); // NOLINT(runtime/int)
312 REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
313 REGISTER_LIBM_SYMBOL(sin, double (*)(double));
314 #ifdef __APPLE__
315 REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
316 registry->Register("__sincosf_stret",
317 reinterpret_cast<void*>(__sincosf_stret));
318 registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
319 #else
320 REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
321 #endif
322 REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
323 REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
324 REGISTER_LIBM_SYMBOL(tan, double (*)(double));
325 REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
326 REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
327 REGISTER_LIBM_SYMBOL(trunc, double (*)(double));
328
329 #undef REGISTER_LIBM_SYMBOL
330
331 registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
332 registry->Register("memmove", reinterpret_cast<void*>(memmove));
333 registry->Register("memset", reinterpret_cast<void*>(memset));
334
335 #ifdef __APPLE__
336 registry->Register("__bzero", reinterpret_cast<void*>(bzero));
337 registry->Register("memset_pattern16",
338 reinterpret_cast<void*>(memset_pattern16));
339 #endif
340
341 #ifdef MEMORY_SANITIZER
342 registry->Register("__msan_unpoison",
343 reinterpret_cast<void*>(__msan_unpoison));
344 #endif
345
346 return true;
347 }
348
349 bool unused = RegisterKnownJITSymbols();
350 } // namespace
351
352 } // namespace cpu
353 } // namespace xla
354