#include <torch/csrc/jit/passes/quantization/finalize.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/quantization/quantization_patterns.h>
#include <torch/csrc/jit/passes/quantization/register_packed_params.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <utility>
#include <vector>

namespace torch {
namespace jit {

namespace {

void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      linear_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      conv_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

void removePackedParamInsertionAndFPWeightsSetAttr(
    std::shared_ptr<Graph>& g,
    const std::unordered_set<std::string>& packed_param_attr_names) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::SetAttr) {
      const std::string& attr_name = n->s(attr::name);
      if (packed_param_attr_names.count(attr_name)) {
        nodes_to_delete.push_back(n);
      } else {
        // The attribute may be set on a submodule: resolve the full access
        // path relative to self and match it against the registered names.
        Value* v = n->input(0);
        Value* self = g->inputs()[0];
        std::vector<std::string> paths = getModuleAccessPath(v, self);
        std::string path = joinPaths(paths);
        if (packed_param_attr_names.count(path)) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  // Drop inputs first so uses between the deleted nodes do not block destroy.
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  ConstantPooling(g);
  EliminateDeadCode(g);
}

void removeObserverCallMethods(std::shared_ptr<Graph>& g) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::CallMethod) {
      const std::string& attr_name = n->s(attr::name);
      if (attr_name == "calculate_qparams") {
        auto observer_node = n->input(0)->node();
        if (observer_node->kind() == prim::GetAttr &&
            observer_node->s(attr::name).find("_observer_") !=
                std::string::npos) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  EliminateDeadCode(g);
}

void keepOnlyPackedParamsGeneration(Module& m, const std::string& method_name) {
  auto g = m.get_method(method_name).graph();
  Function& function = m.get_method(method_name).function();
  const auto& schema = function.getSchema();
  auto new_schema = schema.cloneWithReturns({Argument("", NoneType::get())});
  // Erase index 0 each time: the remaining outputs shift down after each
  // erase.
  for (size_t i = 0, output_size = g->outputs().size(); i < output_size; i++) {
    g->eraseOutput(0);
  }
  Node* none_node = g->createNone();
  g->registerOutput(none_node->output());
  none_node->insertBefore(g->return_node());
  function.setSchema(std::move(new_schema));
  EliminateDeadCode(g);
}

} // namespace

void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) {
  std::vector<QuantFusionInfo> patterns;
  if (quant_type == QuantType::DYNAMIC) {
    patterns = dynamic_quant_fusion_pattern_and_replacements();
    std::vector<QuantFusionInfo> patterns_wo_dynamic_activation_quant =
        dynamic_quantized_linear_pattern_and_replacements();
    patterns.insert(
        patterns.end(),
        patterns_wo_dynamic_activation_quant.begin(),
        patterns_wo_dynamic_activation_quant.end());
  } else {
    patterns = quant_fusion_pattern_and_replacements();
  }
  for (const auto& info : patterns) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(info.pattern, info.replacement);
    rewriter.runOnGraph(graph, info.filters);
  }
}
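// For illustration, each pattern entry above drives SubgraphRewriter with a
// pair of IR pattern strings. The strings below are a hedged, simplified
// sketch of what the linear prepack/unpack rewrite looks like (the
// authoritative patterns live in quantization_patterns.h):
//
//   SubgraphRewriter rewriter;
//   rewriter.RegisterRewritePattern(
//       R"(
// graph(%a_dequant, %w_quant, %b):
//     %w_dequant = aten::dequantize(%w_quant)
//     %r = aten::linear(%a_dequant, %w_dequant, %b)
//     return (%r) )",
//       R"(
// graph(%a_dequant, %w_quant, %b):
//     %packed_params = quantized::linear_prepack(%w_quant, %b)
//     %w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params)
//     %w_dequant = aten::dequantize(%w_unpacked)
//     %r = aten::linear(%a_dequant, %w_dequant, %b_unpacked)
//     return (%r) )");
//   rewriter.runOnGraph(graph);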
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
  insertPrepackUnpackForLinear(graph);
  insertPrepackUnpackForConv(graph);
}

void InsertPrepackUnpack(Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    InsertPrepackUnpack(graph);
  }
  for (Module m : module.children()) {
    InsertPrepackUnpack(m);
  }
}

void FoldQuantizedPrepackingOps(Module& module) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  PrePackingOpsFolder(module, filter_fn, "quantized");
}

static std::unordered_set<std::string> RegisterPrePackingParams(
    Module& module,
    const std::string& method_name) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  return RegisterPrePackParams(module, method_name, filter_fn, "");
}

Module Finalize(
    Module& module,
    QuantType quant_type,
    const std::vector<std::string>& preserved_attrs) {
  // Tracing annotates the resulting graph with shape information, and users
  // often run a traced graph with input shapes that differ from the traced
  // ones; it is on the user to know that doing so is correct. The quantized
  // module therefore needs to be cleaned up: to prevent JIT optimizations
  // from leveraging the annotated shape info, clear the shape information
  // from the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }
  auto graph = module.get_method("forward").graph();
  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto frozen = freeze_module(module, preserved_attrs);
  FoldQuantizedPrepackingOps(frozen);
  return frozen;
}
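// A hedged usage sketch of Finalize: `module` is assumed to be a scripted
// module that has already been through observer insertion, calibration, and
// quant/dequant node insertion (the file names below are illustrative):
//
//   Module module = load("observed_calibrated.pt");
//   std::vector<std::string> preserved_attrs; // attrs to keep past freezing
//   Module quantized = Finalize(module, QuantType::STATIC, preserved_attrs);
//   quantized.save("quantized.pt");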
Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name) {
  // Tracing annotates the resulting graph with shape information, and users
  // often run a traced graph with input shapes that differ from the traced
  // ones; it is on the user to know that doing so is correct. To prevent JIT
  // optimizations from leveraging the annotated shape info, clear the shape
  // information from the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  const std::string kQuantizeString = "quantize_";
  const auto matched_pos = method_name.find(kQuantizeString);
  const auto end_pos = matched_pos + kQuantizeString.length();
  const std::string orig_method_name = method_name.substr(end_pos);
  TORCH_CHECK(
      matched_pos == 0,
      "Quantized ops can only be added to quantize_",
      orig_method_name,
      ". Please make sure to run the quant/dequant node insertion step for "
      "on-device PTQ.");

  const std::string quantized_method_name = "quantized_" + orig_method_name;
  auto graph = module.get_method(method_name).graph();
  // Apply some AOT optimizations here. Of these, CSE seems to be required:
  // without it, in some experiments the serialized model is incorrect, in
  // the sense that it cannot be deserialized. The rest are included as
  // canonical optimizations that are not inference-specific.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  PeepholeOptimize(graph);
  ConstantPropagation(graph);
  UnrollConstantLoops(graph);
  ConstantPooling(graph);

  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto packed_param_attr_names = RegisterPrePackingParams(module, method_name);
  GRAPH_DUMP("After QuantFusion + packed param registration:", graph);

  // At this point we have:
  // 1. inserted the quantized-weight packed params,
  // 2. registered the packed params on the module, and
  // 3. inserted the quantized ops.
  // What remains is to:
  // 1. replicate this method into quantized_<method>,
  // 2. in quantized_<method>, remove the SetAttrs for the fp weights that
  //    quantize_<method> resets,
  // 3. remove the SetAttr nodes for the packed params, which subsequently
  //    lets DCE optimize away the nodes producing them, and
  // 4. modify quantize_<method> to remove all nodes except the SetAttrs.
  cloneMethod(module, method_name, quantized_method_name);
  auto quantized_graph = module.get_method(quantized_method_name).graph();
  removePackedParamInsertionAndFPWeightsSetAttr(
      quantized_graph, packed_param_attr_names);
  // Removing the packed-param insertion is not sufficient on its own: DCE
  // will not remove the observer nodes' GetAttrs and CallMethods, because
  // CallMethods have side effects.
  removeObserverCallMethods(quantized_graph);

  // This step removes the return output from the graph, and the subsequent
  // DCE then removes all remaining ops; after that the only thing the method
  // still does is generate the packed_params.
  keepOnlyPackedParamsGeneration(module, method_name);
  return module;
}

} // namespace jit
} // namespace torch
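// A hedged end-to-end sketch of FinalizeOnDevicePTQ: it assumes the earlier
// on-device PTQ passes have already produced a "quantize_forward" method on
// the module (the method name must start with "quantize_"):
//
//   Module m = load("ondevice_ptq_prepared.pt"); // illustrative file name
//   FinalizeOnDevicePTQ(m, QuantType::STATIC, "quantize_forward");
//   // Afterwards:
//   //  - quantize_forward only computes and SetAttrs the packed params and
//   //    returns None (everything else is DCE'd away);
//   //  - quantized_forward runs the fused quantized ops against the packed
//   //    params, with no packing or observer logic left in it.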