• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/Triple.h"
#include "llvm/ExecutionEngine/JITEventListener.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/types.h"
33 
34 namespace xla {
35 namespace cpu {
36 
37 // Simplified LLVM JIT based on the new Orc API.
38 //
39 // This class wraps Orc's functionality into a single interface that only
40 // exposes what we need for XLA.
41 //
42 // Supports JIT-ing multiple modules but without cross-module linking.
43 // Implements eager compilation - the module is lowered to binary as soon as
44 // it's added to the JIT.
45 class SimpleOrcJIT {
46  public:
47   using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer;
48   using CompileFtor = std::function<ObjLayerT::ObjectPtr(llvm::Module&)>;
49   using CompileLayerT = llvm::orc::LegacyIRCompileLayer<ObjLayerT, CompileFtor>;
50   using VModuleKeyT = llvm::orc::VModuleKey;
51 
52   // Create a new JIT, targeting the host architecture.
53   //
54   // {pre,post}_optimization_hook is invoked on the module before/after all
55   // LLVM IR-level optimizations.  post_codegen_hook is invoked after
56   // compiling to machine code.
57   SimpleOrcJIT(
58       const llvm::TargetOptions& target_options,
59       llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
60       bool enable_fast_math, bool disable_expensive_passes,
61       LLVMCompiler::ModuleHook pre_optimization_hook,
62       LLVMCompiler::ModuleHook post_optimization_hook,
63       std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook);
64 
data_layout()65   const llvm::DataLayout& data_layout() const { return data_layout_; }
66 
target_triple()67   const llvm::Triple& target_triple() const {
68     return target_machine_->getTargetTriple();
69   }
70 
71   // Add a module to the JIT. Returns an opaque key that can be used to later
72   // remove this module.
73   VModuleKeyT AddModule(std::unique_ptr<llvm::Module> module);
74 
75   // Remove a module from the JIT and free the memory associated with it.
76   void RemoveModule(VModuleKeyT key);
77 
78   // Get the runtime address of the compiled symbol whose name is given. Returns
79   // nullptr if the symbol cannot be found.
80   llvm::JITSymbol FindCompiledSymbol(const std::string& name);
81 
target_machine()82   llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
83 
84   // Creates an llvm::TargetMachine suitable for JITting code that will run on
85   // the current machine.
86   static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT(
87       const llvm::TargetOptions& target_options,
88       llvm::CodeGenOpt::Level opt_level);
89 
90  private:
91   llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);
92 
93   void NotifyObjectFinalized(
94       const llvm::object::ObjectFile& object,
95       const llvm::RuntimeDyld::LoadedObjectInfo& object_info);
96   void NotifyObjectFreed(const llvm::object::ObjectFile& object);
97 
98   std::vector<VModuleKeyT> module_keys_;
99   std::unique_ptr<llvm::TargetMachine> target_machine_;
100   const llvm::DataLayout data_layout_;
101   llvm::orc::ExecutionSession execution_session_;
102   std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
103   ObjLayerT object_layer_;
104   CompileLayerT compile_layer_;
105 
106   // Non owning pointer to a JIT event listener that registers the JIT events
107   // with an attached GDB.
108   //
109   // Note: we get a pointer to this event listener using
110   // `createGDBRegistrationListener` which makes it look like we're supposed to
111   // free this, but the function is poorly named and really just returns a
112   // pointer to a static object.
113   llvm::JITEventListener* gdb_jit_event_listener_;
114 };
115 
116 }  // namespace cpu
117 }  // namespace xla
118 
119 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
120