1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "coder/coder.h"
17 #include <iomanip>
18 #include <string>
19 #include <vector>
20 #include <map>
21 #include "schema/inner/model_generated.h"
22 #include "tools/common/flag_parser.h"
23 #include "coder/session.h"
24 #include "coder/context.h"
25 #include "utils/dir_utils.h"
26 #include "securec/include/securec.h"
27 #include "src/common/file_utils.h"
28 #include "src/common/utils.h"
29 #include "coder/config.h"
30 #include "coder/generator/component/component.h"
31
32 namespace mindspore::lite::micro {
33 class CoderFlags : public virtual FlagParser {
34 public:
CoderFlags()35 CoderFlags() {
36 AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
37 AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
38 AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86");
39 AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference");
40 AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false);
41 AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump the tensors data for debugging, true | false", false);
42 }
43
44 ~CoderFlags() override = default;
45
46 std::string model_path_;
47 bool support_parallel_{false};
48 std::string code_path_;
49 std::string code_mode_;
50 bool debug_mode_{false};
51 std::string target_;
52 };
53
Run(const std::string & model_path)54 int Coder::Run(const std::string &model_path) {
55 session_ = CreateCoderSession();
56 if (session_ == nullptr) {
57 MS_LOG(ERROR) << "new session failed while running!";
58 return RET_ERROR;
59 }
60 STATUS status = session_->Init(model_path);
61 if (status != RET_OK) {
62 MS_LOG(ERROR) << "Init session failed!";
63 return RET_ERROR;
64 }
65
66 status = session_->Build();
67 if (status != RET_OK) {
68 MS_LOG(ERROR) << "Compile graph failed!";
69 return status;
70 }
71 status = session_->Run();
72 if (status != RET_OK) {
73 MS_LOG(ERROR) << "Generate Code Files error!" << status;
74 return status;
75 }
76 status = session_->GenerateCode();
77 if (status != RET_OK) {
78 MS_LOG(ERROR) << "Generate Code Files error!" << status;
79 }
80 return status;
81 }
82
ParseProjDir(std::string model_path)83 int Configurator::ParseProjDir(std::string model_path) {
84 // split model_path to get model file name
85 proj_dir_ = model_path;
86 size_t found = proj_dir_.find_last_of("/\\");
87 if (found != std::string::npos) {
88 proj_dir_ = proj_dir_.substr(found + 1);
89 }
90 found = proj_dir_.find(".ms");
91 if (found != std::string::npos) {
92 proj_dir_ = proj_dir_.substr(0, found);
93 } else {
94 MS_LOG(ERROR) << "model file's name must be end with \".ms\".";
95 return RET_ERROR;
96 }
97 if (proj_dir_.size() == 0) {
98 proj_dir_ = "net";
99 MS_LOG(WARNING) << "parse model's name failed, use \"net\" instead.";
100 }
101 return RET_OK;
102 }
103
Init(const CoderFlags & flags) const104 int Coder::Init(const CoderFlags &flags) const {
105 static const std::map<std::string, Target> kTargetMap = {
106 {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
107 static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
108 Configurator *config = Configurator::GetInstance();
109
110 std::vector<std::function<bool()>> parsers;
111 parsers.emplace_back([&flags, config]() -> bool {
112 if (!FileExists(flags.model_path_)) {
113 MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid";
114 return false;
115 }
116 if (config->ParseProjDir(flags.model_path_) != RET_OK) {
117 return false;
118 }
119 return true;
120 });
121
122 parsers.emplace_back([&flags, config]() -> bool {
123 auto target_item = kTargetMap.find(flags.target_);
124 MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
125 config->set_target(target_item->second);
126 return true;
127 });
128
129 parsers.emplace_back([&flags, config]() -> bool {
130 auto code_item = kCodeModeMap.find(flags.code_mode_);
131 MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + flags.code_mode_);
132 config->set_code_mode(code_item->second);
133 return true;
134 });
135
136 parsers.emplace_back([&flags, config]() -> bool {
137 if (flags.support_parallel_ && config->target() == kARM32M) {
138 MS_LOG(ERROR) << "arm32M cannot support parallel.";
139 return false;
140 }
141 config->set_support_parallel(flags.support_parallel_);
142 return true;
143 });
144
145 parsers.emplace_back([&flags, config]() -> bool {
146 config->set_debug_mode(flags.debug_mode_);
147 return true;
148 });
149
150 parsers.emplace_back([&flags, config]() -> bool {
151 const std::string slash = std::string(kSlash);
152 if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) {
153 MS_LOG(ERROR) << "code_gen code path " << flags.code_path_ << " is not valid";
154 return false;
155 }
156 config->set_code_path(flags.code_path_);
157 if (flags.code_path_.empty()) {
158 std::string path = ".." + slash + config->proj_dir();
159 config->set_code_path(path);
160 } else {
161 if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) {
162 std::string path = flags.code_path_ + slash + config->proj_dir();
163 config->set_code_path(path);
164 } else {
165 std::string path = flags.code_path_ + config->proj_dir();
166 config->set_code_path(path);
167 }
168 }
169 return InitProjDirs(flags.code_path_, config->proj_dir()) != RET_ERROR;
170 });
171
172 if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) {
173 if (!flags.help) {
174 std::cerr << flags.Usage() << std::endl;
175 return 0;
176 }
177 return RET_ERROR;
178 }
179 auto print_parameter = [](auto name, auto value) {
180 MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value;
181 };
182
183 print_parameter("modelPath", flags.model_path_);
184 print_parameter("projectName", config->proj_dir());
185 print_parameter("target", config->target());
186 print_parameter("codePath", config->code_path());
187 print_parameter("codeMode", config->code_mode());
188 print_parameter("debugMode", config->debug_mode());
189
190 return RET_OK;
191 }
192
RunCoder(int argc,const char ** argv)193 int RunCoder(int argc, const char **argv) {
194 CoderFlags flags;
195 Option<std::string> err = flags.ParseFlags(argc, argv, false, false);
196 if (err.IsSome()) {
197 std::cerr << err.Get() << std::endl;
198 std::cerr << flags.Usage() << std::endl;
199 return RET_ERROR;
200 }
201
202 if (flags.help) {
203 std::cerr << flags.Usage() << std::endl;
204 return RET_OK;
205 }
206
207 Coder code_gen;
208 STATUS status = code_gen.Init(flags);
209 if (status != RET_OK) {
210 MS_LOG(ERROR) << "Coder init Error";
211 return status;
212 }
213 status = code_gen.Run(flags.model_path_);
214 if (status != RET_OK) {
215 MS_LOG(ERROR) << "Coder Run Error.";
216 return status;
217 }
218 MS_LOG(INFO) << "end of Coder";
219 return RET_OK;
220 }
221 } // namespace mindspore::lite::micro
222