• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/graphdef_transform.h"
18 #include <string>
19 #include <algorithm>
20 #include "schema/model_generated.h"
21 #include "src/common/log_adapter.h"
22 #include "src/common/log_util.h"
23 #include "tools/converter/converter_flags.h"
24 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
25 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
26 #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
27 #include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
28 #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
29 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
30 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
31 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
32 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
33 #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
34 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
35 #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
36 #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
37 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
38 #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
39 
40 using std::string;
41 namespace mindspore::lite {
GetGraphNodes()42 std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
43   std::vector<schema::CNodeT *> old_nodes{};
44   old_nodes.resize(graph_defT_->nodes.size());
45   std::transform(graph_defT_->nodes.begin(), graph_defT_->nodes.end(), old_nodes.begin(),
46                  [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
47   return old_nodes;
48 }
49 
// Default-constructed with no graph attached; callers must SetGraphDef()
// before invoking Transform().
GraphDefTransform::GraphDefTransform() = default;
51 
// The graph pointer set via SetGraphDef() is never deleted here, so the
// defaulted destructor suffices (ownership stays with the caller).
GraphDefTransform::~GraphDefTransform() = default;
53 
// Stores a non-owning pointer to the graph that Transform() will mutate in place.
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }
55 
Transform(const converter::Flags & ctx)56 int GraphDefTransform::Transform(const converter::Flags &ctx) {
57   STATUS status;
58   {
59     auto old_nodes = GetGraphNodes();
60     Optimizer unused_op_remove_optimizer;
61     if (!ctx.trainModel) {
62       unused_op_remove_optimizer.AddPass(new DropoutNodeRemovePass());
63     }
64     unused_op_remove_optimizer.AddPass(new IsolatedNodeRemovePass());
65     unused_op_remove_optimizer.AddPass(new SubgraphNodePass(old_nodes));
66     status = unused_op_remove_optimizer.Run(graph_defT_);
67     if (status != RET_OK && status != RET_NO_CHANGE) {
68       MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
69       return status;
70     }
71   }
72 
73   // format transpose global optimize
74   {
75     // init old node indices
76     auto old_nodes = GetGraphNodes();
77     Optimizer format_trans_optimizer;
78     if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) {
79       format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
80       format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
81     }
82     status = format_trans_optimizer.Run(graph_defT_);
83     if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
84       MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
85       return status;
86     }
87   }
88 
89   // node replace
90   if (!ctx.trainModel) {
91     // init old node indices
92     auto old_nodes = GetGraphNodes();
93     Optimizer replace_optimizer;
94     replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
95     replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk));
96     replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
97     replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
98     status = replace_optimizer.Run(graph_defT_);
99     if (status != RET_OK && status != RET_NO_CHANGE) {
100       MS_LOG(ERROR) << "Run replace_optimizer BatchNormConvertScalePass Failed";
101       return status;
102     }
103   }
104 
105   // node fusion
106   {
107     // init old node indices
108     auto old_nodes = GetGraphNodes();
109     Optimizer fusion_optimizer;
110     fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass());
111     fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
112     fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
113     status = fusion_optimizer.Run(graph_defT_);
114     if (status != RET_OK && status != RET_NO_CHANGE) {
115       MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed";
116       return status;
117     }
118   }
119 
120   // quantization
121   if (ctx.fmk != converter::kFmkTypeTf) {
122     // init old node indices
123     auto old_nodes = GetGraphNodes();
124     Optimizer tensor_quant_optimizer;
125     tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
126     tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
127     tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
128     tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
129     tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
130     status = tensor_quant_optimizer.Run(graph_defT_);
131     if (status != RET_OK) {
132       MS_LOG(ERROR) << "DoQuantize failed!";
133       return status;
134     }
135   }
136 
137   // quantization
138   if (ctx.fmk != converter::kFmkTypeTf) {
139     // init old node indices
140     Optimizer quant_node_optimizer;
141     quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
142     auto old_nodes = GetGraphNodes();
143     quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
144     quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
145     quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
146     quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
147     quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
148     status = quant_node_optimizer.Run(graph_defT_);
149     if (status != RET_OK && status != RET_NO_CHANGE) {
150       MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
151       return status;
152     }
153   }
154 
155   {
156     // init old node indices
157     auto old_nodes = GetGraphNodes();
158     Optimizer switch_optimizer;
159     switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
160     switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
161     switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
162     status = switch_optimizer.Run(graph_defT_);
163     if (status != RET_OK && status != RET_NO_CHANGE) {
164       MS_LOG(ERROR) << "Run switch_optimizer Failed";
165       return status;
166     }
167   }
168 
169   {
170     Optimizer nested_loop_optimizer;
171     auto old_nodes = GetGraphNodes();
172     nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
173     nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
174     status = nested_loop_optimizer.Run(graph_defT_);
175     if (status != RET_OK && status != RET_NO_CHANGE) {
176       MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
177       return status;
178     }
179   }
180 
181   {
182     Optimizer forming_model_optimizer;
183     forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
184     forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
185     forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
186     forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16));
187     status = forming_model_optimizer.Run(graph_defT_);
188     if (status != RET_OK) {
189       MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
190       return status;
191     }
192   }
193   return RET_OK;
194 }
195 }  // namespace mindspore::lite
196