#include #include #include #include namespace torch::jit::graph_rewrite_helper { std::string getFuncName(Value* func_value) { auto func = func_value->type()->expectRef().function(); const auto& qname = func->qualname(); const auto& name = qname.qualifiedName(); auto rdot_idx = name.rfind('.'); if (rdot_idx != std::string::npos) { return name.substr(rdot_idx + 1, name.length()); } else { return name; } } Value* getValue( const std::string& name, const std::unordered_map& match_vmap, const std::unordered_map& vmap) { return match_vmap.at(vmap.at(name)); } std::optional getIValue( const std::string& name, const std::unordered_map& match_vmap, const std::unordered_map& vmap) { return toIValue(getValue(name, match_vmap, vmap)); } static std::unordered_map getConvParams( const Match& match, const std::unordered_map& vmap) { std::unordered_map calc_values; const auto& match_vmap = match.values_map; auto transposed_value = getIValue("transposed", match_vmap, vmap).value(); calc_values["transposed"] = transposed_value; auto output_padding_value = getIValue("output_padding", match_vmap, vmap).value(); calc_values["output_padding"] = output_padding_value; auto stride_value = getIValue("stride", match_vmap, vmap).value(); calc_values["stride"] = stride_value; auto padding_value = getIValue("padding", match_vmap, vmap).value(); calc_values["padding"] = padding_value; auto dilation_value = getIValue("dilation", match_vmap, vmap).value(); calc_values["dilation"] = dilation_value; return calc_values; } void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { // TODO: remove constant prop in the pass ConstantPropagation(graph); std::string convolution_deprecated = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled) return (%r) )"; std::string convolution = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32) return (%r) )"; std::string conv2d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv2d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv1d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv1d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv3d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv3d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv_transpose1d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; std::string conv_transpose1d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; std::string conv_transpose2d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; std::string conv_transpose2d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; std::string conv_transpose3d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; std::string conv_transpose3d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; // Filter the unsupported case auto filter_conv1d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 1 || calc_value_map["stride"].toIntList().size() != 1 || calc_value_map["padding"].toIntList().size() != 1 || calc_value_map["dilation"].toIntList().size() != 1) { return false; } return !calc_value_map["transposed"].toBool(); }; auto filter_conv2d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 2 || calc_value_map["stride"].toIntList().size() != 2 || calc_value_map["padding"].toIntList().size() != 2 || calc_value_map["dilation"].toIntList().size() != 2) { return false; } return !calc_value_map["transposed"].toBool(); }; auto filter_conv3d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 3 || calc_value_map["stride"].toIntList().size() != 3 || calc_value_map["padding"].toIntList().size() != 3 || calc_value_map["dilation"].toIntList().size() != 3) { return false; } return !calc_value_map["transposed"].toBool(); }; auto filter_conv_transpose1d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 1 || calc_value_map["stride"].toIntList().size() != 1 || calc_value_map["padding"].toIntList().size() != 1 || calc_value_map["dilation"].toIntList().size() != 1) { return false; } return calc_value_map["transposed"].toBool(); }; auto filter_conv_transpose2d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 2 || calc_value_map["stride"].toIntList().size() != 2 || calc_value_map["padding"].toIntList().size() != 2 || calc_value_map["dilation"].toIntList().size() != 2) { return false; } return calc_value_map["transposed"].toBool(); }; auto filter_conv_transpose3d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); if (calc_value_map["output_padding"].toIntList().size() != 3 || calc_value_map["stride"].toIntList().size() != 3 || calc_value_map["padding"].toIntList().size() != 3 || calc_value_map["dilation"].toIntList().size() != 3) { return false; } return calc_value_map["transposed"].toBool(); }; SubgraphRewriter rewriter_conv1d; rewriter_conv1d.RegisterRewritePattern(convolution, conv1d); rewriter_conv1d.RegisterRewritePattern( convolution_deprecated, conv1d_for_deprecated_conv); rewriter_conv1d.runOnGraph(graph, filter_conv1d); SubgraphRewriter rewriter_conv2d; rewriter_conv2d.RegisterRewritePattern(convolution, conv2d); rewriter_conv2d.RegisterRewritePattern( convolution_deprecated, conv2d_for_deprecated_conv); rewriter_conv2d.runOnGraph(graph, filter_conv2d); SubgraphRewriter rewriter_conv3d; rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); rewriter_conv3d.RegisterRewritePattern( convolution_deprecated, conv3d_for_deprecated_conv); rewriter_conv3d.runOnGraph(graph, filter_conv3d); SubgraphRewriter rewriter_conv_transpose1d; rewriter_conv_transpose1d.RegisterRewritePattern( convolution, conv_transpose1d); rewriter_conv_transpose1d.RegisterRewritePattern( convolution_deprecated, conv_transpose1d_for_deprecated_conv); rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d); SubgraphRewriter rewriter_conv_transpose2d; rewriter_conv_transpose2d.RegisterRewritePattern( convolution, conv_transpose2d); rewriter_conv_transpose2d.RegisterRewritePattern( convolution_deprecated, conv_transpose2d_for_deprecated_conv); rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d); SubgraphRewriter rewriter_conv_transpose3d; rewriter_conv_transpose3d.RegisterRewritePattern( convolution, conv_transpose3d); rewriter_conv_transpose3d.RegisterRewritePattern( convolution_deprecated, conv_transpose3d_for_deprecated_conv); rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d); } bool isClampFusable( const Match& match, const std::unordered_map& vmap) { const auto& match_vmap = match.values_map; TORCH_CHECK( vmap.find("dummy_min_max") != vmap.end(), "Expected to find dummy_min_max Value in the subgraph to be replaced."); auto dummy_min_max = graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap); auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone(); // Also check if the output_min and output_max values are actually constant. // If hardtanh's min/max Value's are not actually constants, we will end up // rerouting those values to prepack op. And if they are not constants // we will not be able to remove prepacking ops. if (vmap.find("output_min") != vmap.end()) { // aten::relu pattern does not have output_min/output_max. // aten::hardtanh/_ does. TORCH_CHECK( vmap.find("output_max") != vmap.end(), "Expected to find output_max as well given " "output_min exist in pattern graph."); // If output_min/max are not constant, we get std::nullopt. auto output_min = graph_rewrite_helper::getIValue("output_min", match_vmap, vmap); auto output_max = graph_rewrite_helper::getIValue("output_max", match_vmap, vmap); is_fusable = is_fusable && (output_min.has_value() && output_max.has_value()); } return is_fusable; } } // namespace torch::jit::graph_rewrite_helper