#pragma once

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <ATen/ATen.h>

#include <utility>

namespace torch::jit::tensorexpr {

template <typename T>
class PaddedBuffer;

class TORCH_API CodeGen {
 public:
  class BufferArg;
  class CallArg;

  template <typename... Ts>
  CodeGen(StmtPtr stmt, Ts... ts)
      : stmt_(std::move(stmt)), buffer_args_({BufferArg(ts)...}) {}

  CodeGen(
      StmtPtr stmt,
      std::vector<BufferArg> buffer_args,
      at::Device device = at::kCPU,
      std::string kernel_func_name = "func");

  virtual ~CodeGen() = default;

  StmtPtr stmt() const {
    return stmt_;
  }

  void set_stmt(StmtPtr s) {
    stmt_ = std::move(s);
  }

  void apply_mutator(IRMutator* mutator) {
    stmt_ = stmt_->accept_mutator(mutator);
  }

  void apply_visitor(IRVisitor* visitor) {
    stmt_->accept(visitor);
  }

  std::vector<BufferArg>& buffer_args() {
    return buffer_args_;
  }

  const std::vector<BufferArg>& buffer_args() const {
    return buffer_args_;
  }

  at::Device device() {
    return device_;
  }

  // Returns the generated code as a string.
  virtual std::string getCodeText(
      const std::string& attr [[maybe_unused]] = "") {
    return "";
  }

  // TODO: Figure out how to unify these call interfaces.

  /// Call a function with a vector of CallArgs, which are tagged
  /// unions that properly type the arguments.
  virtual void call(const std::vector<CallArg>& args) = 0;

  /// Call a function faster than a regular `call` by assuming that
  /// the generated kernel already knows the type of the arguments, so
  /// they can be type-punned with `void*`s.
  virtual void call_raw(const std::vector<void*>& args) = 0;

  /// Call a function even faster than a regular call, by assuming
  /// that the number of thread blocks can be derived from `numel` via
  /// a simple division, rather than evaluating an expression.
  virtual void call_with_numel(void** args, int64_t numel);

  virtual at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) {
    return at::empty_strided(
        size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  }

  const std::string& kernel_func_name() const {
    return kernel_func_name_;
  }

  void allocIntermediateBufs();

 protected:
  static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> buffer_args_;
  at::Device device_ = at::kCPU;
  std::string kernel_func_name_ = "func";
};

class TORCH_API ExtCallMemoryReuse : public IRMutator {
  static std::unordered_map<std::string, std::string> makeExtCallFuncNameMap();
  static const std::unordered_map<std::string, std::string>
      extCallFuncNameMap_;

 public:
  explicit ExtCallMemoryReuse(
      const std::vector<CodeGen::BufferArg>& bufferArgs);
  ~ExtCallMemoryReuse() override = default;
  StmtPtr mutate(const ExternalCallPtr& v) override;

 private:
  std::unordered_set<BufPtr> bufferArgs_;
};

class CodeGen::BufferArg {
 public:
  BufferArg(const Tensor& tensor) : buf_(tensor.buf()) {}
  BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
  BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
  BufferArg(BufPtr buf) : buf_(std::move(buf)) {}

  VarPtr var() const {
    return isVar_ ? var_ : buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  bool isVar() const {
    return isVar_;
  }

  Dtype dtype() const {
    return isVar_ ? var_->dtype() : buf_->dtype();
  }

 private:
  VarPtr var_ = nullptr;
  BufPtr buf_ = nullptr;
  bool isVar_ = false;
};
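// Illustrative sketch (not part of this header): BufferArg converts
// implicitly from Tensor, BufHandle, and VarHandle, so a CodeGen parameter
// list can be written as a plain initializer list. The handle names and
// sizes below are made up for the example.
//
//   BufHandle input("in", {64}, kFloat);
//   VarHandle scale("scale", kFloat);
//   std::vector<CodeGen::BufferArg> params = {input, scale};
//   // params[0].isVar() == false (a buffer); params[1].isVar() == true (a var)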
class CodeGen::CallArg {
 public:
  template <typename T>
  CallArg(const PaddedBuffer<T>& buffer);

  template <typename T>
  CallArg(const std::vector<T>& buffer)
      : data_(const_cast<T*>(buffer.data())) {}

  CallArg(void* ptr) : data_(ptr) {}

#define ARG_TYPE_CTOR(Type, Name)      \
  CallArg(Type v) {                    \
    memcpy(buffer_, &v, sizeof(Type)); \
    data_ = (void*)buffer_;            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
#undef ARG_TYPE_CTOR

  void* data() const {
    return data_;
  }

  CallArg(const CallArg& rhs) {
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
  }

  CallArg& operator=(const CallArg& rhs) {
    if (this == &rhs) {
      return *this;
    }
    if (rhs.data_ == rhs.buffer_) {
      memcpy(this->buffer_, rhs.buffer_, sizeof(rhs.buffer_));
      this->data_ = (void*)(this->buffer_);
    } else {
      this->data_ = rhs.data_;
    }
    return *this;
  }

#define ARG_PTR_DEFINE(Type, Name)                  \
  Type* Name##Ptr() const {                         \
    TORCH_INTERNAL_ASSERT(data_ == (void*)buffer_); \
    return (Type*)data_;                            \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
#undef ARG_PTR_DEFINE

 private:
  void* data_;

  // For a scalar value, CallArg stores it directly in the pointer slot
  // (treating &data_ as the storage). But a pointer is only 32 bits wide on a
  // 32-bit platform, so it cannot hold a scalar wider than 32 bits, such as
  // double or int64_t. Hence we add an 8-byte buffer dedicated to storing the
  // scalar value, regardless of whether its bit width is smaller or larger
  // than 32 bits.
  char buffer_[8] = {0}; // 64 bits
};

class RegisterCodeGenList {
 public:
  TORCH_API static RegisterCodeGenList& GetInstance();

  using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>&,
      at::Device device,
      const std::string& kernel_func_name)>;

  TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
  RegisterCodeGenList(const RegisterCodeGenList&) = delete;
  RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;

 private:
  template <class CodeGenType>
  friend class RegisterCodeGen;
  RegisterCodeGenList() = default;
  TORCH_API void AddStmtFactoryMethod(
      const std::string& name,
      const StmtFactoryMethod& stmt_factory_method);

  std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};

template <class CodeGenType>
class RegisterCodeGen {
 public:
  explicit RegisterCodeGen(const std::string& name) {
    RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
    codegen_list.AddStmtFactoryMethod(
        name,
        [](StmtPtr stmt,
           const std::vector<CodeGen::BufferArg>& params,
           at::Device device,
           const std::string& kernel_func_name) {
          std::unique_ptr<CodeGen> method(
              new CodeGenType(stmt, params, device, kernel_func_name));
          return method;
        });
  }
};

TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device = at::kCPU,
    const std::string& kernel_func_name = "func");

class TORCH_API GenericIntrinsicsExpander : public IRMutator {
 protected:
  ExprPtr mutate(const IntrinsicsPtr& v) override;
};

} // namespace torch::jit::tensorexpr
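// Usage sketch, assuming a backend has been registered under a name such as
// "llvm_codegen" (via RegisterCodeGen above); the buffer names, sizes, and
// backend string here are illustrative assumptions, not guarantees made by
// this header.
//
//   BufHandle a("A", {64}, kFloat);
//   BufHandle b("B", {64}, kFloat);
//   Tensor c = Compute("C", {64}, [&](const VarHandle& i) {
//     return a.load(i) + b.load(i);
//   });
//   LoopNest nest({c});
//   nest.prepareForCodegen();
//   std::unique_ptr<CodeGen> cg = CreateCodeGen(
//       "llvm_codegen", IRSimplifier::simplify(nest.root_stmt()), {a, b, c});
//   std::vector<float> av(64, 1.f), bv(64, 2.f), cv(64, 0.f);
//   cg->call({av, bv, cv}); // CallArg converts from std::vector<T>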