#pragma once

#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <optional>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

class LLVMCodeGenImpl;
class LLVMCodeGenCallee;

class TORCH_API LLVMCodeGen : public CodeGen {
 public:
  explicit LLVMCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func",
      Dtype dtype = kInt,
      std::optional<std::string> triple = std::nullopt,
      std::optional<std::string> cpu = std::nullopt,
      std::optional<std::string> attrs = std::nullopt);
  explicit LLVMCodeGen(StmtPtr stmt);

  LLVMCodeGen() = delete;
  ~LLVMCodeGen() override;

  // Cleans up all the memory used during the LLVM code generation pass except
  // the generated kernel. After calling this method, users should not call
  // methods like `getCodeText` that require the LLVMCodeGenImpl data. However,
  // users can continue to call this kernel using `call` and `call_raw`.
  void cleanup_memory();

  TORCH_API void call(const std::vector<CallArg>& args) override;
  TORCH_API void call_raw(const std::vector<void*>& args) override;
  TORCH_API void call_with_numel(void** args, int64_t numel) override;

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

  template <typename T>
  T value() {
    return value<T>(nullptr);
  }

  template <typename T>
  T value(std::vector<void*>& args) {
    return value<T>(args.data());
  }

  template <typename T>
  T value(void** args) {
    T (*fp)(void**) = (T(*)(void**))getKernelAddress(callee_.get());
    T rv = fp(args);
    return rv;
  }

  std::string getCodeText(const std::string& attr = "") override;

 private:
  void* getKernelAddress(LLVMCodeGenCallee* callee);

  std::unique_ptr<LLVMCodeGenCallee> callee_;
  std::unique_ptr<LLVMCodeGenImpl> impl_;
};

struct TORCH_API LLVMCodeGenBuilder {
  using BufferArg = CodeGen::BufferArg;

  LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
      : stmt_(stmt), args_(std::move(args)) {}

  LLVMCodeGenBuilder& device(at::Device device) {
    device_ = device;
    return *this;
  }

  LLVMCodeGenBuilder& kernelFuncName(std::string name) {
    kernelFuncName_ = std::move(name);
    return *this;
  }

  LLVMCodeGenBuilder& dtype(Dtype d) {
    dtype_ = d;
    return *this;
  }

  LLVMCodeGenBuilder& triple(std::string triple) {
    triple_ = std::move(triple);
    return *this;
  }

  LLVMCodeGenBuilder& cpu(std::string cpu) {
    cpu_ = std::move(cpu);
    return *this;
  }

  LLVMCodeGenBuilder& attrs(std::string attrs) {
    attrs_ = std::move(attrs);
    return *this;
  }

  std::unique_ptr<LLVMCodeGen> build() {
    return std::make_unique<LLVMCodeGen>(
        stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
  }

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> args_;
  at::Device device_ = at::kCPU;
  std::string kernelFuncName_ = "func";
  Dtype dtype_ = kInt;
  std::optional<std::string> triple_ = std::nullopt;
  std::optional<std::string> cpu_ = std::nullopt;
  std::optional<std::string> attrs_ = std::nullopt;
};

TORCH_API std::optional<std::string>& LLVMTargetTriple();
TORCH_API std::optional<std::string>& LLVMTargetCPU();
TORCH_API std::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();

} // namespace tensorexpr
} // namespace jit
} // namespace torch

#endif // TORCH_ENABLE_LLVM
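// ----------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the API surface above). The
// identifiers `stmt`, `a_buf`, `b_buf`, `a_data`, and `b_data` are
// hypothetical placeholders for a lowered tensorexpr statement, its buffer
// arguments, and raw data pointers produced elsewhere in the pipeline:
//
//   auto cg = LLVMCodeGenBuilder(stmt, {a_buf, b_buf})
//                 .device(at::kCPU)
//                 .kernelFuncName("my_kernel")
//                 .build();
//
//   std::vector<void*> raw_args = {a_data, b_data};
//   cg->call_raw(raw_args);
//
//   // Once the LLVM-side metadata is no longer needed, it can be released;
//   // the compiled kernel remains callable via `call`/`call_raw`, but
//   // queries such as `getCodeText` are no longer valid afterwards:
//   cg->cleanup_memory();
//   cg->call_raw(raw_args);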