#include #include #include #include #include #include #include #include #include namespace torch::jit { #if AT_MKLDNN_ENABLED() static c10::VaryingShape getSizesOf(Node* n, size_t idx) { auto tt = n->input(idx)->type()->cast(); return tt->sizes(); } static void insertPrePackedConvOpForNode(Node* n) { constexpr int POS_INPUT = 0; constexpr int POS_WEIGHT = 1; if (!tensorexpr::isContiguous( n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) { GRAPH_DEBUG( "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous"); return; } if (!tensorexpr::isContiguous( n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) { GRAPH_DEBUG( "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous"); return; } // Leave depthwise conv2d to NNC if (tensorexpr::conv2dIsSupportedJit(n)) { GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC"); return; } WithInsertPoint guard(n); auto graph = n->owningGraph(); auto input_sizes = getSizesOf(n, POS_INPUT); IValue input_size_value(*input_sizes.concrete_sizes()); auto input_size = graph->insertConstant(input_size_value); auto prepack_node = graph->create( Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1); // skip input value for (const auto i : c10::irange(1, n->inputs().size())) { Value* v = n->input(i); prepack_node->addInput(v); } prepack_node->addInput(input_size); auto attr = graph->insertConstant(IValue("none")); prepack_node->addInput(attr); prepack_node->output()->setType( getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext")); graph->insertNode(prepack_node); auto prepack_conv = graph->insertNode( graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1)); prepack_conv->addInput(n->input(0)); prepack_conv->addInput(prepack_node->output()); prepack_conv->output()->setType(n->output()->type()->cast()); n->output()->replaceAllUsesWith(prepack_conv->output()); } static bool isTensorTypeCPU(Node* node) { for (const auto& input : node->inputs()) { auto type = input->type()->cast(); if (!type) { continue; } auto device = type->device(); if (!device) { return false; } if (!device->is_cpu()) { return false; } } return true; } static void insertPrePackedConvOp(Block* b) { for (Node* n : b->nodes()) { for (Block* b : n->blocks()) { insertPrePackedConvOp(b); } if (n->kind() == aten::conv2d) { if (isTensorTypeCPU(n)) { insertPrePackedConvOpForNode(n); } } } EliminateDeadCode(b); } static void insertMkldnnPrePackedConv2dOp(std::shared_ptr& graph) { insertPrePackedConvOp(graph->block()); } static void insertMkldnnPrePackedOps(std::shared_ptr& graph) { insertMkldnnPrePackedConv2dOp(graph); } static void FuseReluWithPackedOps(std::shared_ptr& graph) { auto conv_op_rstring = at::jit::CodeTemplate(R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str): %packed_weight_bias = mkldnn_prepacked::conv2d_prepack( %weight, %bias, %stride, %padding, %dilation, %groups, %input_size, %dummy_attr) %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias) %res = aten::${op}(%conv2d_res) return (%res))"); auto conv_op_fused_rstring = at::jit::CodeTemplate(R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str): %attr: str = prim::Constant[value="${op_attr}"]() %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack( %weight, %bias, %stride, %padding, %dilation, %groups, %input_size, %attr) %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias) return (%res))"); for (auto const& it : mkldnn::fusion_rewrite_map) { std::string op = it.first; if (op == std::string("none")) { continue; } at::jit::TemplateEnv env; env.s("op", op); at::jit::TemplateEnv env_fused; env_fused.s("op_attr", op); SubgraphRewriter rewriter; rewriter.RegisterRewritePattern( conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused)); auto filters = it.second; rewriter.runOnGraph(graph, filters); } } static void PrePackingOpsFolder(Block* b) { auto is_foldable_op = [](const Node* n) -> bool { return ( n->kind() == Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack")); }; std::unordered_set nodes_to_delete; for (Node* n : b->nodes()) { for (Block* block : n->blocks()) { PrePackingOpsFolder(block); } if (is_foldable_op(n)) { auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n); if (optional_outputs) { auto outputs = optional_outputs.value(); TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output"); Value* prepack_op_value = n->output(0); auto graph = n->owningGraph(); WithInsertPoint ins(prepack_op_value->node()); auto weak_class_obj = outputs[0].toObject()->copy_to_weak_compilation_ref(); Value* packed_weight = graph->insertConstant(weak_class_obj) ->setType(n->output(0)->type()); prepack_op_value->replaceAllUsesWith(packed_weight); nodes_to_delete.insert(n); } } } for (auto n : nodes_to_delete) { n->removeAllInputs(); } for (auto n : nodes_to_delete) { n->destroy(); } } static void FoldPrePackingOps(std::shared_ptr& graph) { PrePackingOpsFolder(graph->block()); } void FuseConvWithEltwise(std::shared_ptr& graph) { GRAPH_DEBUG( "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n", *graph); insertMkldnnPrePackedOps(graph); GRAPH_DEBUG( "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph); FuseReluWithPackedOps(graph); GRAPH_DEBUG( "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph); FoldPrePackingOps(graph); GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph); } #else void FuseConvWithEltwise(std::shared_ptr& graph) { GRAPH_DEBUG("MKLDNN Not enabled"); } #endif // AT_MKLDNN_ENABLED() } // namespace torch::jit