• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/import/mindspore_importer.h"
18 #include <memory>
19 #include <map>
20 #include <set>
21 #include <vector>
22 #include <regex>
23 #include "tools/converter/parser/parser_utils.h"
24 #include "tools/converter/import/primitive_adjust.h"
25 #include "tools/converter/import/mindir_adjust.h"
26 #include "tools/converter/import/mindir_control_flow_adjust.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "tools/common/tensor_util.h"
29 #include "tools/converter/parser/unify_format.h"
30 #include "tools/converter/parser/lstm_adjust_pass.h"
31 #include "nnacl/op_base.h"
32 
33 namespace mindspore::lite {
namespace {
// NOTE(review): presumably the input index of a convolution's weight; it is not referenced
// anywhere in this translation unit and looks like a candidate for removal.
constexpr size_t kConvWeightIndex{2};
}  // namespace
Mindir2AnfAdjust(const FuncGraphPtr & func_graph,const converter::Flags & flag)37 STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) {
38   MS_ASSERT(func_graph != nullptr);
39   auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>();
40   MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr.");
41   primitive_adjust_pass->SetFmkType(flag.fmk);
42   if (!primitive_adjust_pass->Run(func_graph)) {
43     MS_LOG(ERROR) << "primitive adjust failed.";
44     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
45     return RET_ERROR;
46   }
47   auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
48   MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
49   mindir_adjust_pass->SetFmkType(flag.fmk);
50   mindir_adjust_pass->SetTrainFlag(flag.trainModel);
51   if (!mindir_adjust_pass->Run(func_graph)) {
52     MS_LOG(ERROR) << "MindIr adjust failed.";
53     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
54     return RET_ERROR;
55   }
56   auto mindir_control_flow_adjust = std::make_shared<MindIRControlFlowAdjust>();
57   MS_CHECK_TRUE_MSG(mindir_control_flow_adjust != nullptr, RET_NULL_PTR, "mindir_control_flow_adjust is nullptr.");
58   mindir_control_flow_adjust->SetFmkType(flag.fmk);
59   if (!mindir_control_flow_adjust->Run(func_graph)) {
60     MS_LOG(ERROR) << "MindIR control flow adjust failed.";
61     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
62     return RET_ERROR;
63   }
64   return RET_OK;
65 }
66 
Hex2ByteArray(const std::string & hex_str,unsigned char * byte_array,size_t max_len)67 size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) {
68   std::regex r("[0-9a-fA-F]+");
69   if (!std::regex_match(hex_str, r)) {
70     MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
71     return 0;
72   }
73   if (hex_str.size() % 2 == 1) {  // Mod 2 determines whether it is odd
74     MS_LOG(ERROR) << "the hexadecimal dec_key length must be even";
75     return 0;
76   }
77   size_t byte_len = hex_str.size() / 2;  // Two hexadecimal characters represent a byte
78   if (byte_len > max_len) {
79     MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: 64";
80     return 0;
81   }
82   constexpr int32_t a_val = 10;  // The value of 'A' in hexadecimal is 10
83   constexpr size_t half_byte_offset = 4;
84   for (size_t i = 0; i < byte_len; ++i) {
85     size_t p = i * 2;  // The i-th byte is represented by the 2*i and 2*i+1 hexadecimal characters
86     if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
87       byte_array[i] = hex_str[p] - 'a' + a_val;
88     } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
89       byte_array[i] = hex_str[p] - 'A' + a_val;
90     } else {
91       byte_array[i] = hex_str[p] - '0';
92     }
93     if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
94       byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val);
95     } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
96       byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val);
97     } else {
98       byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0');
99     }
100   }
101   return byte_len;
102 }
103 
ProcessDependCnode(const CNodePtr & cnode)104 STATUS MindsporeImporter::ProcessDependCnode(const CNodePtr &cnode) {
105   MS_ASSERT(cnode != nullptr);
106   if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
107     output_tensor_name_.push_back(cnode->fullname_with_scope());
108     return RET_NO_CHANGE;
109   }
110   auto depend_input = cnode->input(1);
111   MS_CHECK_TRUE_MSG(depend_input != nullptr, RET_ERROR, "depend_input is nullptr");
112   if (utils::isa<CNodePtr>(depend_input)) {
113     auto depend_input_cnode = utils::cast<CNodePtr>(depend_input);
114     auto status = ProcessDependCnode(depend_input_cnode);
115     if (status == RET_NO_CHANGE) {
116       return RET_OK;
117     }
118   } else if (utils::isa<ParameterPtr>(depend_input) || utils::isa<ValueNode>(depend_input)) {
119     output_tensor_name_.push_back(depend_input->fullname_with_scope());
120   }
121   return RET_OK;
122 }
123 
GetFuncGraphOutputName(const CNodePtr & return_node)124 STATUS MindsporeImporter::GetFuncGraphOutputName(const CNodePtr &return_node) {
125   MS_ASSERT(return_node != nullptr);
126   for (size_t i = 0; i < return_node->inputs().size(); i++) {
127     auto output_node = return_node->input(i);
128     if (output_node == nullptr) {
129       MS_LOG(ERROR) << "output_node is nullptr.";
130       return RET_ERROR;
131     } else if (output_node->isa<mindspore::CNode>()) {
132       if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
133           opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
134         continue;
135       }
136       auto output_cnode = utils::cast<CNodePtr>(output_node);
137       if (opt::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
138         for (size_t j = 0; j < output_cnode->inputs().size(); j++) {
139           auto tuple_input = output_cnode->input(j);
140           MS_CHECK_TRUE_MSG(tuple_input != nullptr, RET_ERROR, "tuple_input is nullptr");
141           if (!utils::isa<CNodePtr>(tuple_input)) {
142             continue;
143           }
144           auto tuple_input_cnode = utils::cast<CNodePtr>(tuple_input);
145           if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
146               opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
147             continue;
148           }
149           auto status = ProcessDependCnode(tuple_input_cnode);
150           if (status != RET_OK && status != RET_NO_CHANGE) {
151             MS_LOG(ERROR) << "ProcessDependCnode failed.";
152           }
153         }
154       } else if (opt::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
155         auto status = ProcessDependCnode(output_cnode);
156         if (status != RET_OK && status != RET_NO_CHANGE) {
157           MS_LOG(ERROR) << "ProcessDependCnode failed.";
158         }
159       } else {
160         output_tensor_name_.push_back(output_cnode->fullname_with_scope());
161       }
162     }
163   }
164   return RET_OK;
165 }
166 
RemoveUnusedGraphInput(const FuncGraphPtr & func_graph)167 STATUS MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
168   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func_graph is nullptr");
169   std::map<AnfNodePtr, bool> graph_input_map;
170   for (auto &input : func_graph->get_inputs()) {
171     graph_input_map[input] = false;
172   }
173   auto node_list = TopoSort(func_graph->get_return());
174   for (auto &node : node_list) {
175     if (!utils::isa<CNode>(node)) {
176       continue;
177     }
178     auto cnode = node->cast<CNodePtr>();
179     for (size_t i = 0; i < cnode->inputs().size(); i++) {
180       for (auto &input : func_graph->get_inputs()) {
181         if (input == cnode->input(i) && graph_input_map.count(input) == 1) {
182           graph_input_map[input] = true;
183         }
184       }
185     }
186   }
187   for (auto &item : graph_input_map) {
188     if (item.second == false) {
189       func_graph->DropNode(item.first);
190     }
191   }
192   return RET_OK;
193 }
194 
ImportMindIR(const converter::Flags & flag)195 FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
196   FuncGraphPtr func_graph;
197   if (flag.dec_key.size() != 0) {
198     unsigned char key[32];
199     const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32);
200     if (key_len == 0) {
201       return nullptr;
202     }
203     func_graph = LoadMindIR(flag.modelFile, false, key, key_len, flag.dec_mode);
204     auto ret = memset_s(key, sizeof(key), 0, key_len);
205     if (ret != 0) {
206       MS_LOG(EXCEPTION) << "memset_s error";
207     }
208   } else {
209     func_graph = LoadMindIR(flag.modelFile);
210   }
211   if (func_graph == nullptr) {
212     MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
213     MS_LOG(ERROR)
214       << "The model maybe an old model, Please download the package whose version is before 1.2 and then try again.";
215     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
216     return nullptr;
217   }
218   func_graph->set_attr("graph_name", MakeValue("main_graph"));
219   func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
220   auto status = RemoveUnusedGraphInput(func_graph);
221   if (status != RET_OK) {
222     MS_LOG(ERROR) << "RemoveUnusedGraphInput failed.";
223     return nullptr;
224   }
225   status = GetFuncGraphOutputName(func_graph->get_return());
226   if (status != RET_OK) {
227     MS_LOG(ERROR) << "GetFuncGraphOutputName failed.";
228     return nullptr;
229   }
230   if (output_tensor_name_.empty()) {
231     MS_LOG(ERROR) << "Can not find output name.";
232     return nullptr;
233   }
234   ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_tensor_name_);
235 #ifdef ENABLE_LITE_ACL
236   MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310.";
237   return func_graph;
238 #endif
239   if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
240     MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
241     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
242     return nullptr;
243   }
244   auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel);
245   MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
246   if (!unify_format->Run(func_graph)) {
247     MS_LOG(ERROR) << "Run insert transpose failed.";
248     return nullptr;
249   }
250 
251   auto lstm_adjust_pass = std::make_shared<opt::LstmAdjustPass>();
252   MS_CHECK_TRUE_MSG(lstm_adjust_pass != nullptr, nullptr, "lstm_adjust_pass is nullptr.");
253   if (!lstm_adjust_pass->Run(func_graph)) {
254     MS_LOG(ERROR) << "Run mindir lstm adjust failed.";
255     return nullptr;
256   }
257   return func_graph;
258 }
259 }  // namespace mindspore::lite
260