#pragma once #include #include namespace torch { namespace jit { namespace tensorexpr { template using NodePtr = std::shared_ptr; template NodePtr to(const NodePtr& x) { return std::dynamic_pointer_cast(x); } template NodePtr static_to(NodePtr x) { return std::static_pointer_cast(x); } template NodePtr alloc(Args&&... args) { return std::make_shared(std::forward(args)...); } class Buf; class Expr; class Stmt; class Var; using BufPtr = NodePtr; using ExprPtr = NodePtr; using StmtPtr = NodePtr; using VarPtr = NodePtr; class ExprHandle; class VarHandle; class BufHandle; class Add; class And; class BitCast; class Broadcast; class Cast; class CompareSelect; class Div; class IfThenElse; class Intrinsics; class Let; class Load; class Lshift; class Max; class MaxTerm; class Min; class MinTerm; class Mod; class Mul; class Or; class Polynomial; class Ramp; class ReduceOp; class RoundOff; class Rshift; class Store; class Sub; class Term; class Xor; using AddPtr = NodePtr; using AndPtr = NodePtr; using BitCastPtr = NodePtr; using BroadcastPtr = NodePtr; using CastPtr = NodePtr; using CompareSelectPtr = NodePtr; using DivPtr = NodePtr
; using IfThenElsePtr = NodePtr; using IntrinsicsPtr = NodePtr; using LetPtr = NodePtr; using LoadPtr = NodePtr; using LshiftPtr = NodePtr; using MaxPtr = NodePtr; using MaxTermPtr = NodePtr; using MinPtr = NodePtr; using MinTermPtr = NodePtr; using ModPtr = NodePtr; using MulPtr = NodePtr; using OrPtr = NodePtr; using PolynomialPtr = NodePtr; using RampPtr = NodePtr; using ReduceOpPtr = NodePtr; using RoundOffPtr = NodePtr; using RshiftPtr = NodePtr; using StorePtr = NodePtr; using SubPtr = NodePtr; using TermPtr = NodePtr; using XorPtr = NodePtr; class Allocate; class AtomicAdd; class Block; class Cond; class ExternalCall; class ExternalCallWithAlloc; class For; class Free; class FreeExt; class PlacementAllocate; class SyncThreads; using AllocatePtr = NodePtr; using AtomicAddPtr = NodePtr; using BlockPtr = NodePtr; using CondPtr = NodePtr; using ExternalCallPtr = NodePtr; using ExternalCallWithAllocPtr = NodePtr; using ForPtr = NodePtr; using FreePtr = NodePtr; using FreeExtPtr = NodePtr; using PlacementAllocatePtr = NodePtr; using SyncThreadsPtr = NodePtr; #define IMM_DECLARE(Type, Name) \ class Name##Imm; \ using Name##ImmPtr = NodePtr; AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE); #undef IMM_DECLARE } // namespace tensorexpr } // namespace jit } // namespace torch