#pragma once #include #include #include #include #include namespace torch { namespace jit { using namespace torch::jit::tensorexpr; #define IS_NODE(T, node) \ { \ auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ } #define IS_NODE_WITH_NAME(T, node, name) \ auto name = to(node); \ ASSERT_NE(nullptr, name); #define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ NodePtr name = nullptr; \ { \ auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ name = to(node_->src_value()); \ } \ ASSERT_NE(nullptr, name); #define IS_IMM_WITH_VAL(T, node, val) \ { \ auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ ASSERT_EQ(node_->value(), val); \ } #define IS_VAR_WITH_NAME(node, name) \ { \ auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ ASSERT_EQ(node_->name_hint(), name); \ } #define IS_BINOP_W_VARS(T, node, name, v1, v2) \ NodePtr name = nullptr; \ { \ name = to(node); \ ASSERT_NE(nullptr, name); \ IS_VAR_WITH_NAME(name->lhs(), v1); \ IS_VAR_WITH_NAME(name->rhs(), v2); \ } #define IS_BINOP_W_CONST(T, node, name, v, c) \ NodePtr name = nullptr; \ { \ name = to(node); \ ASSERT_NE(nullptr, name); \ IS_VAR_WITH_NAME(name->lhs(), v); \ IS_IMM_WITH_VAL(Int, name->rhs(), c); \ } #define IS_RAND(node) \ { \ auto node_ = to(node); \ ASSERT_NE(nullptr, node_); \ ASSERT_EQ(node_->op_type(), kRand); \ } void checkIR(StmtPtr s, const std::string& pattern); void checkExprIR(ExprPtr e, const std::string& pattern); void checkExprIR(const ExprHandle& e, const std::string& pattern); } // namespace jit } // namespace torch