#include #include #include #include #include namespace torch { namespace jit { using namespace testing; TEST(SubgraphRewriterTest, FilterMatch) { auto graph = std::make_shared(); parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b : int = prim::Constant[value=1]() %c = c::ccc(%a, %b) return (%c))IR", graph.get()); std::string pattern = R"IR( graph(%a, %b): %c = c::ccc(%a, %b) return (%c))IR"; Graph pattern_graph; std::unordered_map vmap; parseIR(pattern, &pattern_graph, vmap); auto b_is_constant = [](const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; auto b_node = match_vmap.at(vmap.at("b"))->node(); return b_node->kind() == prim::Constant; }; auto b_is_one = [](const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; auto b_val = toIValue(match_vmap.at(vmap.at("b"))); return b_val && b_val->isInt() && b_val->toInt() == 1; }; auto b_is_two = [](const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; auto b_val = toIValue(match_vmap.at(vmap.at("b"))); return b_val && b_val->isInt() && b_val->toInt() == 2; }; std::string replacement = R"IR( graph(%a, %b): %d = d::ddd(%a, %b) return (%d))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); // b is constant, so the match will succeed { auto g = graph->copy(); rewriter.runOnGraph(g, b_is_constant); FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g); } // b is constant and the value is one, the match will succeed { auto g = graph->copy(); rewriter.runOnGraph(g, {b_is_constant, b_is_one}); FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g); } // b is constant but the value is not two, the match will fail { auto g = graph->copy(); rewriter.runOnGraph(g, {b_is_constant, b_is_two}); FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g); } } TEST(SubgraphRewriterTest, FilterNoMatch) { auto graph = std::make_shared(); parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b = prim::Constant[value=1]() %c = c::ccc(%a, %b) return (%c))IR", graph.get()); std::string pattern = R"IR( graph(%a, %b): %c = c::ccc(%a, %b) return (%c))IR"; Graph pattern_graph; std::unordered_map vmap; parseIR(pattern, &pattern_graph, vmap); auto filter = [](const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; auto b_node = match_vmap.at(vmap.at("b"))->node(); // b_node is not prim::Assign, so this won't match and we'll skip the // rewrite return b_node->kind() == prim::Assign; }; std::string replacement = R"IR( graph(%a, %b): %d = d::ddd(%a, %b) return (%d))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); rewriter.runOnGraph(graph, filter); FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph); } TEST(SubgraphRewriterTest, MultiOutput) { { auto graph = std::make_shared(); // Basic multi-output pattern rewriting parseIR( R"IR( graph(%0, %1): %a1, %a2 = a::aaa(%0, %1) %b = b::bbb(%a1) %c = c::ccc(%b) %x1, %x2 = a::aaa(%c, %a2) %y = b::bbb(%x1) %z = d::ddd(%y) return (%z))IR", graph.get()); std::string pattern = R"IR( graph(%0, %1): %a1, %a2 = a::aaa(%0, %1) %b = b::bbb(%a1) return (%b, %a2))IR"; std::string replacement = R"IR( graph(%a, %b): %x, %y = ab::ababab(%a, %b) return (%x, %y))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); auto g = graph->copy(); rewriter.runOnGraph(g); FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g); } { auto graph = std::make_shared(); // Mimic a real model case parseIR( R"IR( graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4): %a1 = aa::aaa(%x1, %k) %b1_1, %b1_2 = bb::bbb(%y1, %a1) %a2 = aa::aaa(%x2, %k) %b2_1, %b2_2 = bb::bbb(%y2, %a2) %a3 = aa::aaa(%x3, %k) %b3_1, %b3_2 = bb::bbb(%y3, %a3) %a4 = aa::aaa(%x4, %k) %b4_1, %b4_2 = bb::bbb(%y4, %a4) %c = cc::ccc(%b4_1) %d1 = dd::ddd(%b1_2, %m) %e1 = ee::eee(%b1_1, %d1) %d2 = dd::ddd(%b2_2, %m) %e2 = ee::eee(%b2_1, %d2) %d3 = dd::ddd(%b3_2, %m) %e3 = ee::eee(%b3_1, %d3) %d4 = dd::ddd(%b4_2, %m) %e4 = ee::eee(%b4_1, %d4) return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4) )IR", graph.get()); std::string pattern = R"IR( graph(%a, %b, %c, %d): %y0 = aa::aaa(%b, %c) %y1, %y2 = bb::bbb(%a, %y0) %y3 = dd::ddd(%y2, %d) return (%y3, %y1))IR"; std::string replacement = R"IR( graph(%a, %b, %c, %d): %x, %y = ab::ababab(%a, %b, %c, %d) return (%x, %y))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); auto g = graph->copy(); rewriter.runOnGraph(g); FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g); } { auto graph = std::make_shared(); // A case where no rewriting should occur due to data dependencies parseIR( R"IR( graph(%x, %y): %a = aa::aaa(%x) %b = bb::bbb(%a) %e = ee::eee(%b) %c = cc::ccc(%y) %d = dd::ddd(%b, %c) %f = ff::fff(%b, %d) return (%f) )IR", graph.get()); std::string pattern = R"IR( graph(%a, %c): %b = bb::bbb(%a) %d = dd::ddd(%b, %c) return (%d, %b))IR"; std::string replacement = R"IR( graph(%a, %c): %d, %b = db::fused(%a, %c) return (%d, %b))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); auto g = graph->copy(); rewriter.runOnGraph(g); // We should not perform the replacement on the given graph due to data // dependency constraints: the output %b is used in %e, which precedes one // def of the input %c. FileCheck().check_not("db::fused")->run(*g); } } TEST(SubgraphRewriterTest, OutputType) { std::string pattern = R"IR( graph(%a, %b): %c = c::ccc(%a, %b) return (%c))IR"; Graph pattern_graph; std::unordered_map vmap; parseIR(pattern, &pattern_graph, vmap); auto b_is_constant = [](const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; auto b_node = match_vmap.at(vmap.at("b"))->node(); return b_node->kind() == prim::Constant; }; std::string replacement = R"IR( graph(%a, %b): %d = d::ddd(%a, %b) return (%d))IR"; SubgraphRewriter rewriter; rewriter.RegisterRewritePattern(pattern, replacement); { auto graph = std::make_shared(); parseIR( R"IR( graph(%0): %a : Float(10, 20) = a::aaa(%0) %b : int = prim::Constant[value=1]() %c : Float(10, 20) = c::ccc(%a, %b) return (%c))IR", graph.get()); // output has shape info. rewriter.runOnGraph(graph, b_is_constant); FileCheck() .check("Float(10, 20) = d::ddd") ->check_not("c::ccc") ->run(*graph); } { auto graph = std::make_shared(); parseIR( R"IR( graph(%0): %a = a::aaa(%0) %b : int = prim::Constant[value=1]() %c = c::ccc(%a, %b) return (%c))IR", graph.get()); // output has not shape info. rewriter.runOnGraph(graph, b_is_constant); FileCheck().check("Tensor = d::ddd")->check_not("c::ccc")->run(*graph); } } } // namespace jit } // namespace torch