/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/graphdef_transform.h"
#include <string>
#include <algorithm>
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/node_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "tools/converter/legacy_optimizer/graph/const_node_reorder_pass.h"

using std::string;
namespace mindspore::lite {
GraphDefTransform::GraphDefTransform() = default;

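// The transform does not own the MetaGraphT it operates on; the destructor only
// drops the borrowed pointer.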
GraphDefTransform::~GraphDefTransform() { this->graph_defT_ = nullptr; }

void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }

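// Typical driver-side usage, as a minimal sketch (the real call sites live
// elsewhere in the converter, and the variable names here are illustrative):
//   GraphDefTransform transform;
//   transform.SetGraphDef(meta_graph);      // meta_graph: schema::MetaGraphT *
//   auto ret = transform.Transform(param);  // param: std::shared_ptr<ConverterPara>
//   if (ret != RET_OK) {
//     MS_LOG(ERROR) << "Transform meta graph failed: " << ret;
//   }
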
namespace {
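// Collects raw, non-owning pointers to every node in the graph. Optimizer stages
// snapshot this list before running so that SubgraphNodePass can keep the
// subgraphs' node indices in sync with nodes that the passes add or remove.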
std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) {
  std::vector<schema::CNodeT *> old_nodes{};
  old_nodes.resize(graph_defT.nodes.size());
  std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(),
                 [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
  return old_nodes;
}

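// Runs the quantization-related legacy passes (topological sort, shape inference,
// input/output dtype transition, quant-cast fusion, isolated-node cleanup).
// Only QUANT_NONE and QUANT_WEIGHT models go through this stage; other quant
// types are returned unchanged.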
int QuantTransform(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT *graph_defT) {
  MS_ASSERT(param != nullptr && graph_defT != nullptr);
  if (param->commonQuantParam.quant_type == quant::QUANT_NONE ||
      param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) {
    // init old node indices
    Optimizer quant_node_optimizer;
    quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
    auto old_nodes = GetGraphNodes(*graph_defT);
    quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
    quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(static_cast<TypeId>(param->input_data_type),
                                                                   static_cast<TypeId>(param->output_data_type)));
    quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
    quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    auto status = quant_node_optimizer.Run(graph_defT);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
      return status;
    }
  }
  return RET_OK;
}

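// Copies the user-supplied output shapes into the graph's output tensors.
// Third-party models wrapped in a Custom op cannot be shape-inferred, so the
// converter takes the output shapes from the configuration instead.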
int FillGraphOutputShape(schema::MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> &output_shapes) {
  const auto &out_indices = meta_graph->outputIndex;
  // Guard against a mismatch between the configured shapes and the actual graph outputs.
  if (output_shapes.size() != out_indices.size()) {
    MS_LOG(ERROR) << "Output shape count " << output_shapes.size() << " does not match graph output count "
                  << out_indices.size();
    return RET_ERROR;
  }
  for (size_t i = 0; i < out_indices.size(); i++) {
    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
    out_tensor->dims = {};
    for (size_t k = 0; k < output_shapes[i].size(); k++) {
      out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k]));
    }
  }
  return RET_OK;
}

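// Copies the configured third-party input/output formats onto the corresponding
// graph tensors. This assumes the converter has already validated that the
// configured format lists match the graph's input and output counts.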
void FillGraphInputAndOutputFormats(schema::MetaGraphT *meta_graph, const ConverterPara &para) {
  const auto &in_indices = meta_graph->inputIndex;
  for (size_t i = 0; i < in_indices.size(); i++) {
    auto &in_tensor = meta_graph->allTensors[in_indices[i]];
    in_tensor->format = para.thirdPartyModelParam.input_formats[i];
    MS_LOG(DEBUG) << "input " << i << " format: " << EnumNameFormat(in_tensor->format);
  }

  const auto &out_indices = meta_graph->outputIndex;
  for (size_t i = 0; i < out_indices.size(); i++) {
    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
    out_tensor->format = para.thirdPartyModelParam.output_formats[i];
    MS_LOG(DEBUG) << "output " << i << " format: " << EnumNameFormat(out_tensor->format);
  }
}
}  // namespace

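// Applies the legacy graph-level optimizations to graph_defT_ in place. Stages:
// unused-op removal, format-transpose cleanup, quantization transforms, subgraph
// normalization, and the final model-forming passes. Third-party models skip the
// pipeline entirely because their Custom wrapper op cannot be shape-inferred.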
int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> &param) {
  MS_ASSERT(param != nullptr);
  STATUS status;

  if (param->fmk_type == converter::kFmkTypeThirdParty) {
    // The legacy optimizer infers shapes, but the Custom op that wraps a third-party
    // model has no infer-shape function, so legacy optimization is skipped for the
    // kFmkTypeThirdParty case.
    auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret;
      return ret;
    }

    // Tensors of a FuncGraph have no format attribute, so set the format in the MetaGraph.
    FillGraphInputAndOutputFormats(graph_defT_, *param);
    return RET_OK;
  }

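  // Remove ops that are no-ops at inference time (Dropout, unless converting a
  // training model) together with any nodes the removal leaves isolated.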
  {
    auto old_nodes = GetGraphNodes(*graph_defT_);
    Optimizer unused_op_remove_optimizer;
    if (!param->train_model) {
      unused_op_remove_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
    }
    unused_op_remove_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    unused_op_remove_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    status = unused_op_remove_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
      return status;
    }
  }

  // format transpose global optimize
  {
    // init old node indices
    auto old_nodes = GetGraphNodes(*graph_defT_);
    Optimizer format_trans_optimizer;
    if (!param->train_model && param->fmk_type != converter::kFmkTypeOnnx) {
      format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
      format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    }
    status = format_trans_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
      MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
      return status;
    }
  }

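  // Quantization-related dtype and fusion passes; see QuantTransform above.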
  status = QuantTransform(param, graph_defT_);
  if (status != RET_OK && status != RET_NO_CHANGE) {
    return status;
  }

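  // Normalize subgraph structure: drop isolated nodes, reconcile subgraph node and
  // tensor indices, restore topological order, and reorder constant nodes.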
  {
    Optimizer nested_loop_optimizer;
    auto old_nodes = GetGraphNodes(*graph_defT_);
    nested_loop_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
    nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
    nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
    nested_loop_optimizer.AddPass(new (std::nothrow) ConstNodeReorderPass());
    status = nested_loop_optimizer.Run(graph_defT_);
    if (status != RET_OK && status != RET_NO_CHANGE) {
      MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
      return status;
    }
  }

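  // Final model-forming passes: infer shapes, reset quant params that are no longer
  // used, assign node/tensor names, and optionally store FP32 weights as FP16.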
  {
    Optimizer forming_model_optimizer;
    forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
    forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param));
    if (param->train_model) {
      forming_model_optimizer.AddPass(new (std::nothrow) NodeNamePass());
    }
    forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
    forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16));
    status = forming_model_optimizer.Run(graph_defT_);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run forming_model_optimizer graphPasses Failed.";
      return status;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite