/** * This file implements the core classes for Tensor Expressions. * * The structure of the expressions is inspired by Halide/TVM IR. */ #pragma once #include #include #include #include #include #include #include namespace torch::jit::tensorexpr { enum IRNodeType { kPrimitive, kAdd, kSub, kMul, kDiv, kMod, kMax, kMin, kAnd, kOr, kLshift, kRshift, kXor, kCompareSelect, kCast, kBitCast, kOther, }; // The common base between all expression node. class TORCH_API Expr : public std::enable_shared_from_this { public: explicit Expr(Dtype dtype, IRNodeType expr_type = kOther) : dtype_(dtype), expr_type_(expr_type) {} virtual ~Expr() = default; Dtype dtype() const { return dtype_; } virtual void accept(IRVisitor* visitor) = 0; virtual ExprPtr accept_mutator(IRMutator* mutator) = 0; IRNodeType expr_type() const { return expr_type_; } // Is this a fixed (constant) immediate value. virtual bool isConstant() const { return false; } void set_dtype(Dtype dtype) { dtype_ = dtype; } /* * Make a deep copy of the given expression. * * All sub-expressions inside the given expressions are also cloned. Note * that the variables are not deep-copied since they are immutable. */ static ExprPtr clone(const ExprPtr& s); protected: std::shared_ptr getptr() { return shared_from_this(); } private: Dtype dtype_; IRNodeType expr_type_; }; // A CRTP pattern to accept visitors for children class, // and dispatch back to the children. template class ExprNode : public Base { public: using ExprNodeBase = ExprNode; void accept(IRVisitor* visitor) override { visitor->visit(static_to(Base::getptr())); } ExprPtr accept_mutator(IRMutator* mutator) override; // pass the constructor to the base class using Base::Base; }; // A wrapper object to the underlying ExprNode. // Also serves the primary way to build and operate on other expressions. class TORCH_API ExprHandle { public: ExprHandle() = default; explicit ExprHandle(ExprPtr node) : base_expr_node_(std::move(node)) {} ExprPtr node() { return base_expr_node_; } ExprPtr node() const { return base_expr_node_; } bool empty() const { return base_expr_node_ == nullptr; } #define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v); AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE); #undef IMM_EXPR_DECLARE template NodePtr AsNode() { return to(this->node()); } template NodePtr AsNode() const { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return const_cast(this)->AsNode(); } Dtype dtype() const { return node()->dtype(); } // Handling the math operators. ExprHandle operator+(const ExprHandle& other) const; ExprHandle operator-(const ExprHandle& other) const; ExprHandle operator*(const ExprHandle& other) const; ExprHandle operator/(const ExprHandle& other) const; ExprHandle operator%(const ExprHandle& other) const; ExprHandle operator==(const ExprHandle& other) const; ExprHandle operator!=(const ExprHandle& other) const; ExprHandle operator>(const ExprHandle& other) const; ExprHandle operator>=(const ExprHandle& other) const; ExprHandle operator<(const ExprHandle& other) const; ExprHandle operator<=(const ExprHandle& other) const; ExprHandle operator&(const ExprHandle& other) const; ExprHandle operator|(const ExprHandle& other) const; ExprHandle operator&&(const ExprHandle& other) const; ExprHandle operator||(const ExprHandle& other) const; ExprHandle operator^(const ExprHandle& other) const; ExprHandle operator<<(const ExprHandle& other) const; ExprHandle operator>>(const ExprHandle& other) const; private: ExprPtr base_expr_node_ = nullptr; }; // The underlying representation node to a Var. // Currently, each Var object represents a unique variable, even though the // names might be the same. We should consider add a unique_name as well. class TORCH_API Var : public ExprNode { public: static ExprHandle make(const std::string& name_hint, Dtype dtype) { return ExprHandle(alloc(name_hint, dtype)); } static ExprHandle make(Dtype dtype) { return ExprHandle(alloc("", dtype)); } // TODO: unique_name const std::string& name_hint() const { return name_hint_; } void set_name_hint(const std::string& name) { name_hint_ = name; } void set_name_hint(std::string&& name) { name_hint_ = std::move(name); } Var(std::string name_hint, Dtype dtype) : ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {} private: std::string name_hint_; }; TORCH_API std::vector make_contiguous_strides( const std::vector& dims); TORCH_API std::vector make_channels_last_strides( const std::vector& dims); class TORCH_API Buf : public ExprNode { public: static BufHandle make(const std::vector& dims, Dtype dtype); static BufHandle make( const std::string& name_hint, const std::vector& dims, const std::vector& strides, Dtype dtype); static BufHandle make( const std::string& name_hint, const std::vector& dims, Dtype dtype, std::optional initializer = std::nullopt, const std::optional>& strides = std::nullopt, std::optional qscale = std::nullopt, std::optional qzero = std::nullopt); // TODO: unique_name VarPtr base_handle() const { return base_handle_; } void set_base_handle(VarPtr base_handle) { base_handle_ = std::move(base_handle); } const std::string& name_hint() const { return base_handle_->name_hint(); } void set_name_hint(const std::string& name_hint) { base_handle_->set_name_hint(name_hint); } Buf(const std::string& name_hint, const std::vector& dims, Dtype dtype, ExprPtr initializer = nullptr, std::optional> strides = std::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr) : Buf(alloc(name_hint, kHandle), dims, dtype, std::move(initializer), std::move(strides), std::move(qscale), std::move(qzero)) {} Buf(const VarPtr& var, std::vector dims, Dtype dtype, ExprPtr initializer = nullptr, std::optional> strides = std::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr); size_t ndim() const { return dims_.size(); } ExprPtr dim(size_t index) const { if (index >= ndim()) { throw out_of_range_index(); } return dims_[index]; } std::vector dims() const { return dims_; } void set_dims(std::vector dims) { dims_ = std::move(dims); } std::vector strides() const { return strides_; } void set_strides(std::vector strides) { strides_ = std::move(strides); } ExprPtr initializer() const { return initializer_; }; ExprPtr qzero() const { return qzero_; } ExprPtr qscale() const { return qscale_; } void set_qzero(ExprPtr qzero) { qzero_ = std::move(qzero); } void set_qscale(ExprPtr qscale) { qscale_ = std::move(qscale); } bool hasConstantDims() const { for (const auto& d : dims_) { if (!d->isConstant()) { return false; } } return true; } bool is_contiguous( at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const; // The channels-last 1d can benefit the performance of some operators like // conv1d. But the MemoryFormat enum has not covered this layout yet. Hence, // we abstract a dedicated function to check channels-last 1d contiguous. // // Channels-last 1d: // dims: n c l // strides(nlc): c*l 1 c bool is_channels_last_1d_contiguous() const { if (dims_.size() != 3) { return false; } return is_stride_one(1) && is_cont_with(2, 1) && is_cont_with(0, 2); } private: bool is_cont_with(int cur_dim, int adjacent_dim) const; bool is_stride_one(int cur_dim) const; VarPtr base_handle_; std::vector dims_; std::vector strides_; ExprPtr initializer_; // qscale_ and qzero_ are used only for quantized dtypes Bufs: kQUInt8, kQInt8 ExprPtr qscale_; ExprPtr qzero_; }; class TORCH_API BufHandle : public ExprHandle { public: BufHandle( const std::string& name_hint, const std::vector& dims, Dtype dtype) : ExprHandle(Buf::make(name_hint, dims, dtype)) {} BufHandle( const std::string& name_hint, const std::vector& dims, const std::vector& strides, Dtype dtype) : ExprHandle(Buf::make(name_hint, dims, strides, dtype)) {} BufHandle(const std::vector& dims, Dtype dtype) : ExprHandle(Buf::make("_", dims, dtype)) {} explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {} explicit BufHandle(BufPtr node) : ExprHandle(std::move(node)) {} BufPtr node() const { return static_to(ExprHandle::node()); } BufPtr node() { return static_to(ExprHandle::node()); } template inline ExprHandle load(const Ts&... ts) const; template inline ExprHandle load(const std::vector& args) const; inline ExprHandle load(const std::vector& args) const; StorePtr store(const std::vector& args, const ExprHandle& val) const; bool operator==(const BufHandle& other) const { return this->node() == other.node(); } bool operator!=(const BufHandle& other) const { return !(*this == other); } const std::string& name_hint() const { return this->node()->name_hint(); } bool empty() const { return (this->node() == nullptr); } size_t ndim() const { return node()->ndim(); } std::vector dims() const; ExprHandle dim(size_t index) const { return ExprHandle(node()->dim(index)); } bool is_contiguous( at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { return node()->is_contiguous(memory_format); } bool is_channels_last_1d_contiguous() const { return node()->is_channels_last_1d_contiguous(); } }; // An expression to construct the underlying variable node. // Note: do not store any info here, since it is often possible to slice this // object. For example: VarHandle x('x'); ExprHandle x2 = x; class TORCH_API VarHandle : public ExprHandle { public: // Creates an empty VarHandle whose base Var is set to nullptr. VarHandle() : ExprHandle() {} explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {} VarHandle(const std::string& name_hint, Dtype dtype) : ExprHandle(Var::make(name_hint, dtype)) {} explicit VarHandle(VarPtr node) : ExprHandle(std::move(node)) {} VarPtr node() const { return static_to(ExprHandle::node()); } bool operator==(const VarHandle& other) const { return this->node() == other.node(); } bool operator!=(const VarHandle& other) const { return !(*this == other); } const std::string& name_hint() const { return this->node()->name_hint(); } bool empty() const { return (this->node() == nullptr); } }; template ExprPtr ExprNode::accept_mutator(IRMutator* mutator) { return mutator->mutate(static_to(Base::getptr())); } inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) { return expr1.AsNode() == expr2.AsNode(); } TORCH_API ExprHandle sin(const ExprHandle& v); TORCH_API ExprHandle cos(const ExprHandle& v); TORCH_API ExprHandle tan(const ExprHandle& v); TORCH_API ExprHandle asin(const ExprHandle& v); TORCH_API ExprHandle acos(const ExprHandle& v); TORCH_API ExprHandle atan(const ExprHandle& v); TORCH_API ExprHandle sinh(const ExprHandle& v); TORCH_API ExprHandle cosh(const ExprHandle& v); TORCH_API ExprHandle tanh(const ExprHandle& v); TORCH_API ExprHandle sigmoid(const ExprHandle& v); TORCH_API ExprHandle exp(const ExprHandle& v); TORCH_API ExprHandle expm1(const ExprHandle& v); TORCH_API ExprHandle abs(const ExprHandle& v); TORCH_API ExprHandle log(const ExprHandle& v); TORCH_API ExprHandle fast_tanh(const ExprHandle& v); TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v); TORCH_API ExprHandle fast_log(const ExprHandle& v); TORCH_API ExprHandle log_vml(const ExprHandle& v); TORCH_API ExprHandle log2(const ExprHandle& v); TORCH_API ExprHandle log10(const ExprHandle& v); TORCH_API ExprHandle log1p(const ExprHandle& v); TORCH_API ExprHandle erf(const ExprHandle& v); TORCH_API ExprHandle erfc(const ExprHandle& v); TORCH_API ExprHandle sqrt(const ExprHandle& v); TORCH_API ExprHandle rsqrt(const ExprHandle& v); TORCH_API ExprHandle ceil(const ExprHandle& v); TORCH_API ExprHandle floor(const ExprHandle& v); TORCH_API ExprHandle round(const ExprHandle& v); TORCH_API ExprHandle trunc(const ExprHandle& v); TORCH_API ExprHandle frac(const ExprHandle& v); TORCH_API ExprHandle lgamma(const ExprHandle& v); TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle isnan(const ExprHandle& v1); TORCH_API ExprHandle Relu(const ExprHandle& v1); TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes); } // namespace torch::jit::tensorexpr