1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/import/mindspore_importer.h"
18 #include <memory>
19 #include <map>
20 #include <set>
21 #include <vector>
22 #include <regex>
23 #include "tools/converter/parser/parser_utils.h"
24 #include "tools/converter/import/primitive_adjust.h"
25 #include "tools/converter/import/mindir_adjust.h"
26 #include "tools/converter/import/mindir_control_flow_adjust.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "tools/common/tensor_util.h"
29 #include "tools/converter/parser/unify_format.h"
30 #include "tools/converter/parser/lstm_adjust_pass.h"
31 #include "nnacl/op_base.h"
32
33 namespace mindspore::lite {
namespace {
// Presumably the index of the weight tensor among a convolution cnode's inputs.
// NOTE(review): not referenced anywhere in this chunk — confirm it is used
// elsewhere in this file before removing.
constexpr size_t kConvWeightIndex = 2;
}  // namespace
Mindir2AnfAdjust(const FuncGraphPtr & func_graph,const converter::Flags & flag)37 STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) {
38 MS_ASSERT(func_graph != nullptr);
39 auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>();
40 MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr.");
41 primitive_adjust_pass->SetFmkType(flag.fmk);
42 if (!primitive_adjust_pass->Run(func_graph)) {
43 MS_LOG(ERROR) << "primitive adjust failed.";
44 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
45 return RET_ERROR;
46 }
47 auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
48 MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
49 mindir_adjust_pass->SetFmkType(flag.fmk);
50 mindir_adjust_pass->SetTrainFlag(flag.trainModel);
51 if (!mindir_adjust_pass->Run(func_graph)) {
52 MS_LOG(ERROR) << "MindIr adjust failed.";
53 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
54 return RET_ERROR;
55 }
56 auto mindir_control_flow_adjust = std::make_shared<MindIRControlFlowAdjust>();
57 MS_CHECK_TRUE_MSG(mindir_control_flow_adjust != nullptr, RET_NULL_PTR, "mindir_control_flow_adjust is nullptr.");
58 mindir_control_flow_adjust->SetFmkType(flag.fmk);
59 if (!mindir_control_flow_adjust->Run(func_graph)) {
60 MS_LOG(ERROR) << "MindIR control flow adjust failed.";
61 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
62 return RET_ERROR;
63 }
64 return RET_OK;
65 }
66
Hex2ByteArray(const std::string & hex_str,unsigned char * byte_array,size_t max_len)67 size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) {
68 std::regex r("[0-9a-fA-F]+");
69 if (!std::regex_match(hex_str, r)) {
70 MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
71 return 0;
72 }
73 if (hex_str.size() % 2 == 1) { // Mod 2 determines whether it is odd
74 MS_LOG(ERROR) << "the hexadecimal dec_key length must be even";
75 return 0;
76 }
77 size_t byte_len = hex_str.size() / 2; // Two hexadecimal characters represent a byte
78 if (byte_len > max_len) {
79 MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: 64";
80 return 0;
81 }
82 constexpr int32_t a_val = 10; // The value of 'A' in hexadecimal is 10
83 constexpr size_t half_byte_offset = 4;
84 for (size_t i = 0; i < byte_len; ++i) {
85 size_t p = i * 2; // The i-th byte is represented by the 2*i and 2*i+1 hexadecimal characters
86 if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
87 byte_array[i] = hex_str[p] - 'a' + a_val;
88 } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
89 byte_array[i] = hex_str[p] - 'A' + a_val;
90 } else {
91 byte_array[i] = hex_str[p] - '0';
92 }
93 if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
94 byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val);
95 } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
96 byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val);
97 } else {
98 byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0');
99 }
100 }
101 return byte_len;
102 }
103
ProcessDependCnode(const CNodePtr & cnode)104 STATUS MindsporeImporter::ProcessDependCnode(const CNodePtr &cnode) {
105 MS_ASSERT(cnode != nullptr);
106 if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
107 output_tensor_name_.push_back(cnode->fullname_with_scope());
108 return RET_NO_CHANGE;
109 }
110 auto depend_input = cnode->input(1);
111 MS_CHECK_TRUE_MSG(depend_input != nullptr, RET_ERROR, "depend_input is nullptr");
112 if (utils::isa<CNodePtr>(depend_input)) {
113 auto depend_input_cnode = utils::cast<CNodePtr>(depend_input);
114 auto status = ProcessDependCnode(depend_input_cnode);
115 if (status == RET_NO_CHANGE) {
116 return RET_OK;
117 }
118 } else if (utils::isa<ParameterPtr>(depend_input) || utils::isa<ValueNode>(depend_input)) {
119 output_tensor_name_.push_back(depend_input->fullname_with_scope());
120 }
121 return RET_OK;
122 }
123
GetFuncGraphOutputName(const CNodePtr & return_node)124 STATUS MindsporeImporter::GetFuncGraphOutputName(const CNodePtr &return_node) {
125 MS_ASSERT(return_node != nullptr);
126 for (size_t i = 0; i < return_node->inputs().size(); i++) {
127 auto output_node = return_node->input(i);
128 if (output_node == nullptr) {
129 MS_LOG(ERROR) << "output_node is nullptr.";
130 return RET_ERROR;
131 } else if (output_node->isa<mindspore::CNode>()) {
132 if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
133 opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
134 continue;
135 }
136 auto output_cnode = utils::cast<CNodePtr>(output_node);
137 if (opt::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
138 for (size_t j = 0; j < output_cnode->inputs().size(); j++) {
139 auto tuple_input = output_cnode->input(j);
140 MS_CHECK_TRUE_MSG(tuple_input != nullptr, RET_ERROR, "tuple_input is nullptr");
141 if (!utils::isa<CNodePtr>(tuple_input)) {
142 continue;
143 }
144 auto tuple_input_cnode = utils::cast<CNodePtr>(tuple_input);
145 if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
146 opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
147 continue;
148 }
149 auto status = ProcessDependCnode(tuple_input_cnode);
150 if (status != RET_OK && status != RET_NO_CHANGE) {
151 MS_LOG(ERROR) << "ProcessDependCnode failed.";
152 }
153 }
154 } else if (opt::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
155 auto status = ProcessDependCnode(output_cnode);
156 if (status != RET_OK && status != RET_NO_CHANGE) {
157 MS_LOG(ERROR) << "ProcessDependCnode failed.";
158 }
159 } else {
160 output_tensor_name_.push_back(output_cnode->fullname_with_scope());
161 }
162 }
163 }
164 return RET_OK;
165 }
166
RemoveUnusedGraphInput(const FuncGraphPtr & func_graph)167 STATUS MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
168 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr");
169 std::map<AnfNodePtr, bool> graph_input_map;
170 for (auto &input : func_graph->get_inputs()) {
171 graph_input_map[input] = false;
172 }
173 auto node_list = TopoSort(func_graph->get_return());
174 for (auto &node : node_list) {
175 if (!utils::isa<CNode>(node)) {
176 continue;
177 }
178 auto cnode = node->cast<CNodePtr>();
179 for (size_t i = 0; i < cnode->inputs().size(); i++) {
180 for (auto &input : func_graph->get_inputs()) {
181 if (input == cnode->input(i) && graph_input_map.count(input) == 1) {
182 graph_input_map[input] = true;
183 }
184 }
185 }
186 }
187 for (auto &item : graph_input_map) {
188 if (item.second == false) {
189 func_graph->DropNode(item.first);
190 }
191 }
192 return RET_OK;
193 }
194
ImportMindIR(const converter::Flags & flag)195 FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
196 FuncGraphPtr func_graph;
197 if (flag.dec_key.size() != 0) {
198 unsigned char key[32];
199 const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32);
200 if (key_len == 0) {
201 return nullptr;
202 }
203 func_graph = LoadMindIR(flag.modelFile, false, key, key_len, flag.dec_mode);
204 auto ret = memset_s(key, sizeof(key), 0, key_len);
205 if (ret != 0) {
206 MS_LOG(EXCEPTION) << "memset_s error";
207 }
208 } else {
209 func_graph = LoadMindIR(flag.modelFile);
210 }
211 if (func_graph == nullptr) {
212 MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
213 MS_LOG(ERROR)
214 << "The model maybe an old model, Please download the package whose version is before 1.2 and then try again.";
215 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
216 return nullptr;
217 }
218 func_graph->set_attr("graph_name", MakeValue("main_graph"));
219 func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
220 auto status = RemoveUnusedGraphInput(func_graph);
221 if (status != RET_OK) {
222 MS_LOG(ERROR) << "RemoveUnusedGraphInput failed.";
223 return nullptr;
224 }
225 status = GetFuncGraphOutputName(func_graph->get_return());
226 if (status != RET_OK) {
227 MS_LOG(ERROR) << "GetFuncGraphOutputName failed.";
228 return nullptr;
229 }
230 if (output_tensor_name_.empty()) {
231 MS_LOG(ERROR) << "Can not find output name.";
232 return nullptr;
233 }
234 ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_tensor_name_);
235 #ifdef ENABLE_LITE_ACL
236 MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310.";
237 return func_graph;
238 #endif
239 if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
240 MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
241 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
242 return nullptr;
243 }
244 auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel);
245 MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
246 if (!unify_format->Run(func_graph)) {
247 MS_LOG(ERROR) << "Run insert transpose failed.";
248 return nullptr;
249 }
250
251 auto lstm_adjust_pass = std::make_shared<opt::LstmAdjustPass>();
252 MS_CHECK_TRUE_MSG(lstm_adjust_pass != nullptr, nullptr, "lstm_adjust_pass is nullptr.");
253 if (!lstm_adjust_pass->Run(func_graph)) {
254 MS_LOG(ERROR) << "Run mindir lstm adjust failed.";
255 return nullptr;
256 }
257 return func_graph;
258 }
259 } // namespace mindspore::lite
260