#pragma once

#include <unordered_set>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch::jit::tensorexpr {

// A class that analyzes the given program relevant for Cuda backends.
class CudaAnalysis : public IRVisitor {
 public:
  CudaAnalysis() {
    gpu_block_extents_ = {alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    gpu_thread_extents_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }

  bool is_buf_store_target(const BufPtr& buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<VarPtr>& thread_local_bufs() const {
    return thread_local_bufs_;
  }

  const std::unordered_set<VarPtr>& cross_block_bufs() const {
    return cross_block_bufs_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return gpu_block_extents_;
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return gpu_thread_extents_;
  }

 private:
  void visit(const StorePtr& v) override {
    store_targets_.insert(v->buf());
  }

  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const PlacementAllocatePtr& v) override;
  void visit(const ForPtr& v) override;

  std::unordered_set<BufPtr> store_targets_;
  std::unordered_set<VarPtr> thread_local_bufs_;
  std::unordered_set<VarPtr> cross_block_bufs_;

  std::vector<ExprPtr> gpu_block_extents_;
  std::vector<ExprPtr> gpu_thread_extents_;
};

// An IRMutator that replaces binding loop options with Cuda metavars, and
// masks statement blocks which should execute with less reach than the launch
// parameter extent.
//
// We do this by segmenting each block into chunks which should have the same
// execution parameters, then masking each dim whose params differ from the
// launch maximum.
class GPUMetaVarRewriter : public IRMutator {
 public:
  explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis)
      : cuda_analysis_(cuda_analysis) {
    gpu_block_vars_ = {
        alloc<Var>("blockIdx.x", kInt),
        alloc<Var>("blockIdx.y", kInt),
        alloc<Var>("blockIdx.z", kInt)};
    gpu_thread_vars_ = {
        alloc<Var>("threadIdx.x", kInt),
        alloc<Var>("threadIdx.y", kInt),
        alloc<Var>("threadIdx.z", kInt)};

    current_block_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    current_thread_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }

  StmtPtr mutate(const ForPtr& v) override;
  StmtPtr mutate(const BlockPtr& v) override;

  const std::vector<VarPtr>& gpu_block_vars() const {
    return gpu_block_vars_;
  }

  const std::vector<VarPtr>& gpu_thread_vars() const {
    return gpu_thread_vars_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

 private:
  // When processing a block, stores the contents of each sub-segment.
  class Segment {
   public:
    void reset(bool mask) {
      stmts_.clear();
      mask_ = mask;
    }

    bool empty() const {
      return stmts_.empty();
    }

    std::vector<StmtPtr>& stmts() {
      return stmts_;
    }
    bool mask() {
      return mask_;
    }

   private:
    std::vector<StmtPtr> stmts_;
    bool mask_{true};
  };

  // Returns true if the current execution scope is equivalent to the launch
  // parameters.
  bool isFullExtent();

  std::vector<VarPtr> gpu_block_vars_;
  std::vector<VarPtr> gpu_thread_vars_;

  std::vector<ExprPtr> current_block_reach_;
  std::vector<ExprPtr> current_thread_reach_;

  const CudaAnalysis* cuda_analysis_;
};
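// Illustrative sketch (not part of the original header) of the rewrite
// described above: loop axes bound to GPU block/thread options become the
// corresponding metavars, and statements whose reach is smaller than the
// launch extent get masked. For a nest whose outer axis is bound to
// blockIdx.x and inner axis to threadIdx.x:
//
//   for (int i = 0; i < 1024; i++) {      // bound to blockIdx.x
//     for (int j = 0; j < 256; j++) {     // bound to threadIdx.x
//       A[i, j] = ...;
//     }
//     B[i] = ...;                         // no thread axis here
//   }
//
// is conceptually rewritten to:
//
//   A[blockIdx.x, threadIdx.x] = ...;
//   if (threadIdx.x < 1) {                // masked: reach < launch extent
//     B[blockIdx.x] = ...;
//   }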
// A class that overrides the underlying IRPrinter to produce Cuda C.
class CudaPrinter : public IRPrinter {
 public:
  explicit CudaPrinter(
      std::ostream* os,
      const CudaAnalysis* cuda_analysis,
      bool has_random)
      : IRPrinter(*os), cuda_analysis_(cuda_analysis) {
    if (has_random) {
      rand_func_ = alloc<Var>("rand", kHandle);
    }
  }

  void visit(const CastPtr& v) override;
  void visit(const IntrinsicsPtr& v) override;
  void visit(const ForPtr& v) override;

  void visit(const LoadPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const AtomicAddPtr& v) override;
  void visit(const MaxPtr& v) override;
  void visit(const MinPtr& v) override;
  void visit(const IfThenElsePtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const LetPtr& v) override;

  void visit(const ExternalCallPtr& v) override;

  VarPtr rand_func() const {
    return rand_func_;
  }

  std::string dtypeToCppString(const Dtype& dtype) override;

  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  VarPtr rand_func_;
  const CudaAnalysis* cuda_analysis_;

  void print_flat_alloc(const AllocatePtr& alloc);
};

// Construct Cuda C from the buffer and tensor input, and invoke the kernel
// when real arguments are provided.
class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
 public:
  template <typename... Ts>
  CudaCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCUDA, at::cuda::current_device())) {
    Initialize();
  }

  CudaCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCUDA, at::cuda::current_device()),
      const std::string& kernel_func_name = "func")
      : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~CudaCodeGen() override;

  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;
  void call_with_numel(void** args, int64_t numel) override;

  template <typename... Ts>
  void operator()(const Ts&... ts) {
    call(std::vector<CallArg>({CallArg(ts)...}));
  }

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

  std::string getCodeText(const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  void Initialize();

  void CompileToNVRTC(const std::string& code, const std::string& func_name);

  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  std::ostream& os() {
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<CudaPrinter> printer_;
  std::unique_ptr<CudaAnalysis> cuda_analysis_;
  std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
  std::unordered_set<std::string> taken_func_names;
  std::mutex eval_lock_;
  CUfunction function_{nullptr};
  bool has_random_ = false;
  int thread_block_size_ = -1;

  std::vector<bool> arg_pos_in_extents_;
#ifdef TORCH_ENABLE_LLVM
  std::vector<ExprEval<LLVMCodeGen>> block_extents_eval_;
  std::vector<ExprEval<LLVMCodeGen>> thread_extents_eval_;
#else
  std::vector<ExprEval<SimpleIREvaluator>> block_extents_eval_;
  std::vector<ExprEval<SimpleIREvaluator>> thread_extents_eval_;
#endif

  std::string GetUniqueFuncName(const std::string& func_prefix);
};

} // namespace torch::jit::tensorexpr
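// Usage sketch (illustrative only; the lowering steps below are assumptions,
// not part of this header). Construction runs the analysis/rewrite passes
// and compiles the generated Cuda C via NVRTC in Initialize(); operator()
// then launches the compiled kernel:
//
//   using namespace torch::jit::tensorexpr;
//   BufHandle a("a", {1024}, kFloat);
//   BufHandle b("b", {1024}, kFloat);
//   Tensor c = Compute("c", {1024}, [&](const VarHandle& i) {
//     return a.load(i) + b.load(i);
//   });
//   LoopNest nest({c});
//   nest.getLoopStmtsFor(c)[0]->set_gpu_block_index(0); // bind to blockIdx.x
//   nest.prepareForCodegen();
//   CudaCodeGen cg(nest.root_stmt(), c, a, b);
//   cg(c_dev, a_dev, b_dev); // raw device pointers; launches the kernel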