1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "tools/converter/graphdef_transform.h" 18 #include <string> 19 #include <algorithm> 20 #include "schema/model_generated.h" 21 #include "src/common/log_adapter.h" 22 #include "src/common/log_util.h" 23 #include "tools/converter/converter_flags.h" 24 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" 25 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" 26 #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h" 27 #include "tools/converter/legacy_optimizer/graph/infershape_pass.h" 28 #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h" 29 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 30 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" 31 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" 32 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" 33 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" 34 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" 35 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" 36 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h" 37 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 38 
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"

using std::string;
namespace mindspore::lite {
/**
 * Snapshot the raw node pointers currently held by graph_defT_->nodes, in order.
 *
 * The returned pointers are non-owning views of the unique_ptrs in graph_defT_->nodes.
 * Each stage of Transform() takes a fresh snapshot ("old nodes") and hands it to
 * SubgraphNodePass; presumably so that pass can map subgraph node indices from the
 * pre-stage node list to the post-stage one (NOTE(review): confirm in subgraph_node_pass).
 *
 * Precondition: graph_defT_ is non-null (Transform() checks this before calling).
 */
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
  std::vector<schema::CNodeT *> old_nodes;
  old_nodes.reserve(graph_defT_->nodes.size());
  for (const auto &node : graph_defT_->nodes) {
    old_nodes.push_back(node.get());
  }
  return old_nodes;
}

GraphDefTransform::GraphDefTransform() = default;

GraphDefTransform::~GraphDefTransform() = default;

/**
 * Set the meta graph that Transform() rewrites in place.
 *
 * Only the pointer is stored; the destructor is defaulted, so this class does not free it.
 */
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }

/**
 * Run the legacy graph-level optimization pipeline over graph_defT_ in place.
 *
 * Stages run in a fixed order. Each stage builds its own Optimizer, runs it once,
 * and aborts the whole transform if it reports an error. RET_NO_CHANGE means the
 * stage succeeded without modifying the graph, so it is accepted as success by
 * every stage.
 *
 * NOTE(review): passes are allocated with new (std::nothrow) and handed to
 * Optimizer::AddPass, which is assumed to take ownership and to ignore a nullptr
 * from a failed allocation -- confirm in optimizer.cc.
 *
 * @param ctx converter flags; this function reads trainModel, fmk, inputDataType,
 *            outputDataType and saveFP16.
 * @return RET_OK on success, RET_NULL_PTR if SetGraphDef() was not called with a
 *         graph, otherwise the status of the first failing stage.
 */
int GraphDefTransform::Transform(const converter::Flags &ctx) {
  if (graph_defT_ == nullptr) {
    MS_LOG(ERROR) << "graph_defT_ is nullptr, call SetGraphDef before Transform";
    return RET_NULL_PTR;
  }
  STATUS status = RET_OK;

  // Unused-op removal: dropout nodes (only when not converting a training model) and isolated nodes.
  {
    auto old_nodes = GetGraphNodes();
    Optimizer unused_op_remove_optimizer;
    if (!ctx.trainModel) {
      unused_op_remove_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
    }
    unused_op_remove_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    unused_op_remove_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = unused_op_remove_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
      return status;
    }
  }

  // Format transpose global optimize. As written this stage only re-runs isolated-node
  // removal, and only for non-training, non-ONNX models. RET_INFER_INVALID is tolerated here.
  {
    auto old_nodes = GetGraphNodes();
    Optimizer format_trans_optimizer;
    if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) {
      format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
      format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    }
    status = format_trans_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
      MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
      return status;
    }
  }

  // Node replacement (inference models only): shape inference, then BatchNorm -> Scale conversion.
  if (!ctx.trainModel) {
    auto old_nodes = GetGraphNodes();
    Optimizer replace_optimizer;
    replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
    replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk));
    replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    replace_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = replace_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run replace_optimizer graphPasses Failed";
      return status;
    }
  }

  // Node fusion: Mul + Add fusion, then clean up nodes left isolated by the fusion.
  {
    auto old_nodes = GetGraphNodes();
    Optimizer fusion_optimizer;
    fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass());
    fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = fusion_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed";
      return status;
    }
  }

  // Quantization, tensor level (skipped for TF models): sort, infer quant params and shapes, quantize tensors.
  if (ctx.fmk != converter::kFmkTypeTf) {
    auto old_nodes = GetGraphNodes();
    Optimizer tensor_quant_optimizer;
    tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
    tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
    tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
    tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
    tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = tensor_quant_optimizer.Run(graph_defT_);
    // RET_NO_CHANGE only means no pass modified the graph; it is not a quantization failure.
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "DoQuantize failed!";
      return status;
    }
  }

  // Quantization, node level (skipped for TF models): input/output dtype conversion and quant-cast fusion.
  if (ctx.fmk != converter::kFmkTypeTf) {
    auto old_nodes = GetGraphNodes();
    Optimizer quant_node_optimizer;
    quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
    quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
    quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
    quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
    quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = quant_node_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
      return status;
    }
  }

  // Subgraph clean-up: remove isolated nodes and fix up subgraph node/tensor bookkeeping.
  {
    auto old_nodes = GetGraphNodes();
    Optimizer switch_optimizer;
    switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
    status = switch_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run switch_optimizer Failed";
      return status;
    }
  }

  // Subgraph node fix-up followed by a final topological sort.
  {
    auto old_nodes = GetGraphNodes();
    Optimizer nested_loop_optimizer;
    nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
    status = nested_loop_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
      return status;
    }
  }

  // Forming the final model: shape inference, default unused quant params, tensor naming, optional FP16 save.
  {
    Optimizer forming_model_optimizer;
    forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
    forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
    forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
    forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16));
    status = forming_model_optimizer.Run(graph_defT_);
    // As with every other stage, "nothing changed" is success rather than an error.
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run forming_model_optimizer graphPasses Failed.";
      return status;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite