#ifdef TORCH_ENABLE_LLVM

// NOTE(review): the original include targets were lost (the directives were
// empty). The list below is reconstructed from the names this file uses --
// confirm against version control before merging.
#include <gtest/gtest.h>

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

#include <cmath>
#include <limits>
#include <sstream>
#include <string>

using namespace torch::indexing;
namespace te = torch::jit::tensorexpr;

// Splits the first loop that `ln` reports for `target` into an inner loop of
// `width` iterations plus a tail loop, then vectorizes the inner loop.
// Records a fatal gtest failure if `target` has no loops or if vectorization
// is rejected.
static void vectorize(te::LoopNest* ln, te::Tensor target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  // Guard the loops[0] access below: indexing an empty vector is undefined.
  ASSERT_FALSE(loops.empty());
  te::ForPtr inner, tail;
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ASSERT_TRUE(te::LoopNest::vectorize(inner));
}

// Builds a human-readable summary of how `a` and `b` differ: the number of
// elements whose absolute difference is greater than zero, the largest
// absolute difference, and the index (into the flattened tensors) where it
// occurs. Used only to enrich assertion-failure messages.
std::string diffs(const at::Tensor& a, const at::Tensor& b) {
  auto diff = torch::abs(a.flatten() - b.flatten());
  auto count_diffs = torch::sum(diff > 0.f);
  auto greatest_diff_index = torch::argmax(diff);
  std::stringstream ss;
  ss << "Found " << count_diffs << " unequal element(s). "
     << "The greatest difference was " << diff.index({greatest_diff_index})
     << " at index " << greatest_diff_index;
  return ss.str();
}

// Compiles B[i] = log_vml(A[i]) with the LLVM backend (inner loop vectorized
// by 8) and compares its output against at::log over sweeps of float inputs.
TEST(Approx, log_vml) {
  te::VarHandle N("N", te::kInt);
  te::BufHandle A("A", {N}, te::kFloat);
  te::Tensor B = te::Compute(
      "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });

  te::LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 8);
  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});

  auto eps = std::numeric_limits<float>::epsilon();

  // Runs the compiled kernel on `A_t` and checks it against at::log.
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    // Typed accessors: the kernel was built for te::kFloat buffers, so make
    // the dtype expectation explicit instead of passing untyped pointers.
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    cg.call({ap, bp, A_t.numel()});
    // Results should match to within a relative tolerance of one float
    // epsilon, with no absolute slack; NaNs compare equal to each other.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << "\n"
        << diffs(B_t, B_ref);
  };

  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);

  // Repeat the same sweep scaled by a range of powers of two.
  test(A_t);
  test(A_t * 2.0f);
  test(A_t * 0.5f);
  test(A_t * 4.0f);
  test(A_t * 0.25f);
  test(A_t * powf(2.0f, 16));
  test(A_t * powf(2.0f, -16));
  test(A_t * powf(2.0f, 126));
  test(A_t * powf(2.0f, -126));

  // Special values.
  test(torch::full({32}, INFINITY));
  test(torch::full({32}, NAN));

  auto min = std::numeric_limits<float>::min();
  auto denorm_min = std::numeric_limits<float>::denorm_min();

  // Denormals aren't bit precise, because sleef isn't bit-precise either.
  // Hence this sweep uses allclose's default tolerances rather than `test`.
  A_t = torch::arange(0.0f, min, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
  ASSERT_TRUE(torch::allclose(B_t, B_ref));
}

#endif // TORCH_ENABLE_LLVM