// Unit tests for tensorexpr's CppPrinter. Each test builds a small IR
// expression/statement, prints it with CppPrinter, and checks the emitted C++
// text. STR_CHECK checks it by exact string equality; FILE_CHECK checks it
// with FileCheck directives.
//
// NOTE(review): in this copy every `#include` below except test_base.h has
// lost its header name. Many template argument lists also appear to have been
// stripped:
//   - code uses such as `alloc(10)`, `static_cast(i)`, `std::vector dims` and
//     `std::unordered_set> unsupported_ops`;
//   - expected strings such as "static_cast(1)", "std::bitcast(20.f)" and
//     `static_cast(malloc(24))` in a FileCheck pattern.
// As written this will not compile. Presumably the upstream file spells these
// like `alloc<IntImm>(10)`. Restore them from the source of truth before
// building — TODO confirm.
#include
#include "test/cpp/tensorexpr/test_base.h"
#include
#include
#include
#include
#include
#include

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

// Prints `node` with a CppPrinter into a local std::stringstream and asserts
// that the printed text equals `expected` exactly.
// The expansion declares the locals `ss` and `printer`, so the macro can be
// expanded at most once per scope.
#define STR_CHECK(node, expected) \
  std::stringstream ss;           \
  CppPrinter printer(&ss);        \
  printer.visit(node);            \
  ASSERT_EQ(ss.str(), expected)

// Like STR_CHECK, but runs torch::jit::testing::FileCheck with the directives
// in `pattern` against the printed text instead of comparing exactly.
// The same once-per-scope restriction applies, because of `ss` and `printer`.
// NOTE(review): in this copy each FileCheck pattern string sits on a single
// line. If FileCheck reads one directive per line, the later `# CHECK:`
// markers become part of the first directive's text. Upstream presumably has
// one directive per line — confirm.
#define FILE_CHECK(node, pattern) \
  std::stringstream ss;           \
  CppPrinter printer(&ss);        \
  printer.visit(node);            \
  torch::jit::testing::FileCheck().run(pattern, ss.str())

// Immediates. The expected strings below pin the literal formatting:
//   - the IntImm and HalfImm tests expect a bare "10";
//   - the FloatImm tests expect an `f` suffix ("10.f");
//   - the DoubleImm tests expect a decimal point ("10.0", "10.1").
TEST(CppPrinter, IntImm) {
  auto i = alloc(10);
  STR_CHECK(i, "10");
}

TEST(CppPrinter, FloatImm) {
  auto f = alloc(10);
  STR_CHECK(f, "10.f");
}

// NOTE(review): this test is token-for-token identical to FloatImm above. By
// analogy with DoubleImm1 it was probably meant to exercise a non-integral
// value — confirm the intended input and expected output.
TEST(CppPrinter, FloatImm1) {
  auto f = alloc(10);
  STR_CHECK(f, "10.f");
}

TEST(CppPrinter, DoubleImm) {
  auto d = alloc(10);
  STR_CHECK(d, "10.0");
}

TEST(CppPrinter, DoubleImm1) {
  auto d = alloc(10.1);
  STR_CHECK(d, "10.1");
}

TEST(CppPrinter, HalfImm) {
  auto h = alloc(10);
  STR_CHECK(h, "10");
}

// Binary arithmetic. The expected strings pin the operator spelling and which
// nested operands the printer wraps in parentheses.
TEST(CppPrinter, Add) {
  auto add = alloc(alloc(1), alloc(2));
  STR_CHECK(add, "1 + 2");
}

TEST(CppPrinter, AddExpr1) {
  auto add = alloc(
      alloc(alloc(0), alloc(1)),
      alloc(alloc(2), alloc(3)));
  STR_CHECK(add, "(0 + 1) + (2 - 3)");
}

TEST(CppPrinter, AddExpr2) {
  auto add = alloc(
      alloc(alloc(0), alloc(1)),
      alloc(alloc(2), alloc(3)));
  STR_CHECK(add, "0 * 1 + (2 - 3)");
}

TEST(CppPrinter, AddExpr3) {
  auto add = alloc(
      alloc(alloc(0), alloc(1)),
      alloc(alloc(2), alloc(3)));
  STR_CHECK(add, "(0 + 1) + 2 / 3");
}

// Modulo: the integer case expects the `%` operator, and the float case
// expects a call to std::fmod.
TEST(CppPrinter, Mod) {
  auto mod = alloc(alloc(1), alloc(2));
  STR_CHECK(mod, "1 % 2");
}

TEST(CppPrinter, ModFloat) {
  auto mod = alloc(alloc(1), alloc(2));
  STR_CHECK(mod, "std::fmod(1.f, 2.f)");
}

// Max (the trailing `false` argument is passed through as-is). The int and
// float cases expect std::max; the half case expects a ternary expression
// instead.
TEST(CppPrinter, Max) {
  auto max = alloc(alloc(1), alloc(2), false);
  STR_CHECK(max, "std::max(1, 2)");
}

TEST(CppPrinter, MaxFloat) {
  auto max = alloc(alloc(1), alloc(2), false);
  STR_CHECK(max, "std::max(1.f, 2.f)");
}

TEST(CppPrinter, MaxHalf) {
  auto max = alloc(alloc(1), alloc(2), false);
  STR_CHECK(max, "(1 < 2) ? 2 : 1");
}

TEST(CppPrinter, And) {
  auto v = alloc(alloc(1), alloc(2));
  STR_CHECK(v, "1 & 2");
}

// A kLE compare-select is expected to print as a parenthesized ternary that
// picks the third or fourth operand.
TEST(CppPrinter, CompareSelect) {
  auto cs = alloc(
      alloc(1),
      alloc(2),
      alloc(1),
      alloc(2),
      CompareSelectOperation::kLE);
  STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)");
}

// Conditional expression: condition, true value and false value are each
// printed inside a parenthesized ternary.
TEST(CppPrinter, IfThenElse) {
  auto cond = alloc(alloc(1), alloc(2));
  auto true_value = alloc(alloc(0), alloc(1));
  auto false_value = alloc(alloc(2), alloc(3));
  auto v = alloc(cond, true_value, false_value);
  STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)");
}

// Allocate/Free of a {2, 3} int buffer is expected to print as malloc/free
// inside a block. The 24 is presumably 2 * 3 elements * 4-byte int.
// Note: the local `alloc` below shadows the `alloc(...)` helper used by the
// other tests, for the rest of this test body.
TEST(CppPrinter, AllocateFree) {
  BufHandle buf("x", {2, 3}, kInt);
  AllocatePtr alloc = Allocate::make(buf);
  FreePtr free = Free::make(buf);
  BlockPtr block = Block::make({alloc, free});

  const std::string pattern = R"( # CHECK: { # CHECK: int* x = static_cast(malloc(24)); # CHECK: free(x); # CHECK: } )";
  FILE_CHECK(block, pattern);
}

// Multi-dimensional load/store. The expected string shows indices flattened
// into a single subscript using the buffers' trailing dimension (4 for B,
// 3 for A).
TEST(CppPrinter, LoadStore) {
  BufHandle a("A", {2, 3}, kInt);
  BufHandle b("B", {3, 4}, kInt);
  auto store = b.store({2, 2}, a.load(1, 1));
  STR_CHECK(
      store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n");
}

TEST(CppPrinter, Var) {
  auto var = alloc("x", kInt);
  STR_CHECK(var, "x");
}

TEST(CppPrinter, Cast) {
  auto cast = alloc(kFloat, alloc(1));
  STR_CHECK(cast, "static_cast(1)");
}

TEST(CppPrinter, BitCast) {
  auto cast = alloc(kInt, alloc(20));
  STR_CHECK(cast, "std::bitcast(20.f)");
}

// Let binding of a float var is expected to print as a typed declaration
// terminated by ";\n".
TEST(CppPrinter, Let) {
  auto var = alloc("x", kFloat);
  auto val = alloc(2);
  auto let = alloc(var, val);
  STR_CHECK(let, "float x = 2.f;\n");
}

// Element-wise add over N = 1024 elements: expects a counted for-loop whose
// body stores A[i] + B[i] into C[i].
TEST(CppPrinter, For) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  VarHandle i("i", kInt);
  auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i))));

  const std::string pattern = R"( # CHECK: for (int i = 0; i < 1024; i++) { # CHECK: C[i] = (A[i]) + (B[i]); # CHECK: } )";
  FILE_CHECK(f, pattern);
}

// Cond statement: expects an if/else whose condition is the printed
// compare-select ternary.
TEST(CppPrinter, Cond) {
  BufHandle x("X", {1}, kInt);
  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
  auto cond =
      Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));

  const std::string pattern = R"( # CHECK: if (((X[0] < 10) ? 1 : 0)) { # CHECK: X[0] = (X[0]) + 1; # CHECK: } else { # CHECK: X[0] = (X[0]) - 1; # CHECK: } )";
  FILE_CHECK(cond, pattern);
}

// Walks every intrinsic op value in [0, kMaxIntrinsicsOp), skipping kRand and
// kSigmoid. For each op it checks the expected "std::<func_name>(...)" call:
//   - one-argument ops are called with 2.f;
//   - all other ops are called with (1.f, 2.f).
// Each STR_CHECK sits in its own if/else branch scope, so the macro's `ss`
// and `printer` locals do not collide across the two expansions.
TEST(CppPrinter, Intrinsics) {
  const std::unordered_set> unsupported_ops{
      kRand, kSigmoid};
  for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) {
    IntrinsicsOp op = static_cast(i);
    if (unsupported_ops.count(op)) {
      continue;
    }

    if (Intrinsics::OpArgCount(op) == 1) {
      auto v = alloc(op, alloc(2.0f));
      STR_CHECK(v, "std::" + v->func_name() + "(2.f)");
    } else {
      auto v = alloc(op, alloc(1.0f), alloc(2.0f));
      STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)");
    }
  }
}

// External call to "nnc_aten_matmul" with an output buffer, two input
// buffers (all 2x2 float) and one scalar argument (1 + 2). The pattern
// expects the printer to marshal these into local arrays before the call:
//   - buffer pointers, ranks, flattened dims and dtype codes;
//   - the extra scalar arguments;
//   - then the call itself, passing the buffer count (3) and the extra-arg
//     count (1).
// The dtype code 6 is presumably the numeric code for float — confirm.
TEST(CppPrinter, ExternalCall) {
  std::vector dims{alloc(2), alloc(2)};
  auto output = alloc("out", dims, kFloat);
  auto buf_arg1 = alloc("a", dims, kFloat);
  auto buf_arg2 = alloc("b", dims, kFloat);
  auto scalar_arg = alloc(alloc(1), alloc(2));
  std::vector buf_args{buf_arg1, buf_arg2};
  std::vector scalar_args{scalar_arg};
  auto call = alloc(output, "nnc_aten_matmul", buf_args, scalar_args);

  const std::string pattern = R"( # CHECK: { # CHECK: void* buf_ptrs[]{out, a, b}; # CHECK: int64_t buf_ranks[]{2, 2, 2}; # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; # CHECK: int8_t buf_dtypes[]{6, 6, 6}; # CHECK: int64_t extra_args[]{1 + 2}; # CHECK: nnc_aten_matmul( # CHECK: 3, # CHECK: buf_ptrs, # CHECK: buf_ranks, # CHECK: buf_dims, # CHECK: buf_dtypes, # CHECK: 1, # CHECK: extra_args); # CHECK: } )";
  FILE_CHECK(call, pattern);
}

} // namespace jit
} // namespace torch