#include <ATen/core/jit_type.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch::jit {
namespace {

// Replaces aten::batch_norm with a prepack/run op pair so the batch-norm
// parameters are packed into a Vulkan context once instead of per call.
void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
  std::string batchnorm_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
            %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
        return (%res))";

  SubgraphRewriter batchnorm_rewriter;
  batchnorm_rewriter.RegisterRewritePattern(
      batchnorm_pattern, prepacked_ops_pattern);
  batchnorm_rewriter.runOnGraph(graph);
}

// Canonicalizes decomposed addmm/matmul+add forms into aten::linear, then
// replaces aten::linear with prepacked Vulkan linear ops. The weight is
// transposed up front so the packed context stores it in the layout the
// Vulkan kernels expect.
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %weight_t = aten::t(%weight)
        %packed_weight_bias = vulkan_prepack::create_linear_context(
            %weight_t, %bias)
        %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

// Replaces aten::layer_norm with prepacked Vulkan layer-norm ops.
void insertPrePackedLayernormOp(std::shared_ptr<Graph>& graph) {
  std::string layernorm_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %r = aten::layer_norm(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.LayernormPackedContext = vulkan_prepack::create_layernorm_context(
            %weight, %bias, %eps)
        %res = vulkan_prepack::run_layernorm_context(%input, %normalized_shape, %op_context)
        return (%res))";

  SubgraphRewriter layernorm_rewriter;
  layernorm_rewriter.RegisterRewritePattern(
      layernorm_pattern, prepacked_ops_pattern);
  layernorm_rewriter.runOnGraph(graph);
}
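// The conv rewrites below intentionally insert the packed contexts with None
// clamp bounds (%output_min_max). If a relu/hardtanh immediately follows the
// convolution, fuseReluWithPackedOps / fuseHardtanhWithPackedOps (defined
// further down) re-emit the pattern with real clamp bounds baked into the
// packed context, eliminating the standalone activation node.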
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);

  std::string conv_2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
        return (%res) )";

  std::string prepacked_ops_conv2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_tconv2d_context(
            %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias)
        return (%res) )";

  SubgraphRewriter transpose_rewriter;
  transpose_rewriter.RegisterRewritePattern(
      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
  transpose_rewriter.runOnGraph(graph);
}

void insertPrePackedConv1dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %packed_weight_bias = vulkan_prepack::create_conv1d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups)
        %r = vulkan_prepack::run_conv1d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_1d_pattern, prepacked_ops_conv1d_pattern);
  rewriter.runOnGraph(graph);
}

// Inserts aten::to nodes so tensor-typed graph inputs are moved to the Vulkan
// backend and tensor-typed graph outputs are moved back to CPU.
void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
  // Move inputs to Vulkan backend
  for (Value* input : graph->inputs()) {
    NamedValue named_input = NamedValue("", input);
    if (named_input.type()->kind() == TypeKind::TensorType &&
        !input->uses().empty()) {
      // find the insertion point
      WithInsertPoint ip(input->uses()[0].user->prev());
      Value* replaced_input = graph->insert(
          Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
      // replace the input
      input->replaceAllUsesAfterNodeWith(
          replaced_input->node(), replaced_input);
    }
  }

  // Move outputs to CPU backend
  at::ArrayRef<Value*> outputs = graph->outputs();
  for (size_t i = 0; i < outputs.size(); i++) {
    Value* output = outputs[i];
    NamedValue named_output = NamedValue("", output);
    if (named_output.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("aten::to"), {named_output, "cpu"});
      // replace the output
      graph->block()->replaceOutput(i, replaced_output);
    }
  }
}

void transferInputOutputBackends(script::Module& module) {
  std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
  transferInputOutputBackends(graph);
}

void eliminateDeadCode(script::Module& module) {
  for (auto& method : module.get_methods()) {
    EliminateDeadCode(method.graph());
  }
}
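// Maps CPU quantized ops (quantized::add/mul/conv2d/conv_transpose2d/
// conv2d_relu/linear) onto their Vulkan counterparts. For the conv and linear
// variants, the CPU packed parameters are converted in place into Vulkan
// packed contexts rather than re-packed from raw tensors; for conv2d_relu the
// clamp lower bound of 0.0 is baked directly into the converted context.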
void rewriteQuantizedOps(std::shared_ptr<Graph>& graph) {
  // quantized::add
  std::string quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point):
        %res = quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point):
        %res = vulkan_quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_add_rewriter;
  quantized_add_rewriter.RegisterRewritePattern(
      quantized_add_pattern, vk_quantized_add_pattern);
  quantized_add_rewriter.runOnGraph(graph);

  // quantized::mul
  std::string quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point):
        %res = quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point):
        %res = vulkan_quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_mul_rewriter;
  quantized_mul_rewriter.RegisterRewritePattern(
      quantized_mul_pattern, vk_quantized_mul_pattern);
  quantized_mul_rewriter.runOnGraph(graph);

  // quantized::conv2d
  std::string quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %res = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_rewriter;
  quantized_conv2d_rewriter.RegisterRewritePattern(
      quantized_conv2d_pattern, vk_quantized_conv2d_pattern);
  quantized_conv2d_rewriter.runOnGraph(graph);

  // quantized::conv_transpose2d
  std::string quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %res = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qtconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv_transpose2d_rewriter;
  quantized_conv_transpose2d_rewriter.RegisterRewritePattern(
      quantized_conv_transpose2d_pattern,
      vk_quantized_conv_transpose2d_pattern);
  quantized_conv_transpose2d_rewriter.runOnGraph(graph);

  // quantized::conv2d_relu
  std::string quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %res = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min, %output_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_relu_rewriter;
  quantized_conv2d_relu_rewriter.RegisterRewritePattern(
      quantized_conv2d_relu_pattern, vk_quantized_conv2d_relu_pattern);
  quantized_conv2d_relu_rewriter.runOnGraph(graph);

  // quantized::linear
  std::string quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %res = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %vk_packed_params : __torch__.torch.classes.vulkan.LinearPackedContext = vulkan_quantized_prepack::convert_linear_context(
            %packed_params)
        %res = vulkan_prepack::run_qlinear_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_linear_rewriter;
  quantized_linear_rewriter.RegisterRewritePattern(
      quantized_linear_pattern, vk_quantized_linear_pattern);
  quantized_linear_rewriter.runOnGraph(graph);
}
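// The GRU/LSTM rewrites below pack all per-layer weights and biases into a
// single context object. Each registers a match filter that checks the
// captured parameter value really is a Tensor[] list, since the textual
// pattern alone does not constrain the runtime type of the matched value.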
void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
  std::string gru_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_gru_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases)
        return (%y.1, %hn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("params_cpu"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter gru_rewriter;
  gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern);
  gru_rewriter.runOnGraph(graph, filter);
}

void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) {
  std::string lstm_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1, %cn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_lstm_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx)
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases)
        return (%y.1, %hn.1, %cn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("hx"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter lstm_rewriter;
  lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern);
  lstm_rewriter.runOnGraph(graph, filter);
}
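// Clamp fusion: each pass below matches a prepacked conv whose context was
// created with dummy (None) clamp bounds and is immediately followed by an
// activation, and rewrites it to a context created with the activation's real
// bounds. Both the functional and in-place activation variants map to the
// same fused form. isClampFusable gates the rewrite to matches where the
// activation arguments are actually foldable constants.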
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu_(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedLayernormOp(graph);
  insertPrePackedConv2dOp(graph);
  insertPrePackedConv1dOp(graph);
  rewriteQuantizedOps(graph);
  insertPrePackedGruOp(graph);
  insertPrePackedLstmOp(graph);
  insertPrePackedBatchNormOp(graph);
}
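// Module-level wrapper: applies the graph-level rewrites to every method of
// the module and recurses into child modules so nested submodules are
// rewritten as well.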
void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    vulkanInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    vulkanInsertPrePackedOps(m);
  }
}

void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_tconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv1d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_layernorm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

static void vulkanRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

static void vulkanRunCanonicalOptimizations(script::Module& module) {
  for (const auto& method : module.get_methods()) {
    auto method_graph = method.graph();
    runOptimization(method_graph, false /* no loop unrolling */);
  }
}

script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  if (!optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // remove duplicated constants
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}

} // namespace torch::jit
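// Usage sketch (illustrative only, not part of this translation unit): how a
// client might apply this pipeline to a scripted model. The file paths are
// hypothetical.
//
//   #include <torch/csrc/jit/passes/vulkan_rewrite.h>
//   #include <torch/script.h>
//
//   torch::jit::script::Module m = torch::jit::load("model.pt");
//   torch::jit::script::Module optimized =
//       torch::jit::vulkanOptimizeForMobile(
//           m,
//           /*optimization_blocklist=*/{},
//           /*preserved_methods=*/{});
//   optimized._save_for_mobile("model_vulkan.ptl");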