• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "coder/coder.h"
17 #include <iomanip>
18 #include <string>
19 #include <vector>
20 #include <map>
21 #include "schema/inner/model_generated.h"
22 #include "tools/common/flag_parser.h"
23 #include "coder/session.h"
24 #include "coder/context.h"
25 #include "utils/dir_utils.h"
26 #include "securec/include/securec.h"
27 #include "src/common/file_utils.h"
28 #include "src/common/utils.h"
29 #include "coder/config.h"
30 #include "coder/generator/component/component.h"
31 
32 namespace mindspore::lite::micro {
33 class CoderFlags : public virtual FlagParser {
34  public:
CoderFlags()35   CoderFlags() {
36     AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
37     AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
38     AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86");
39     AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference");
40     AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false);
41     AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump the tensors data for debugging, true | false", false);
42   }
43 
44   ~CoderFlags() override = default;
45 
46   std::string model_path_;
47   bool support_parallel_{false};
48   std::string code_path_;
49   std::string code_mode_;
50   bool debug_mode_{false};
51   std::string target_;
52 };
53 
Run(const std::string & model_path)54 int Coder::Run(const std::string &model_path) {
55   session_ = CreateCoderSession();
56   if (session_ == nullptr) {
57     MS_LOG(ERROR) << "new session failed while running!";
58     return RET_ERROR;
59   }
60   STATUS status = session_->Init(model_path);
61   if (status != RET_OK) {
62     MS_LOG(ERROR) << "Init session failed!";
63     return RET_ERROR;
64   }
65 
66   status = session_->Build();
67   if (status != RET_OK) {
68     MS_LOG(ERROR) << "Compile graph failed!";
69     return status;
70   }
71   status = session_->Run();
72   if (status != RET_OK) {
73     MS_LOG(ERROR) << "Generate Code Files error!" << status;
74     return status;
75   }
76   status = session_->GenerateCode();
77   if (status != RET_OK) {
78     MS_LOG(ERROR) << "Generate Code Files error!" << status;
79   }
80   return status;
81 }
82 
ParseProjDir(std::string model_path)83 int Configurator::ParseProjDir(std::string model_path) {
84   // split model_path to get model file name
85   proj_dir_ = model_path;
86   size_t found = proj_dir_.find_last_of("/\\");
87   if (found != std::string::npos) {
88     proj_dir_ = proj_dir_.substr(found + 1);
89   }
90   found = proj_dir_.find(".ms");
91   if (found != std::string::npos) {
92     proj_dir_ = proj_dir_.substr(0, found);
93   } else {
94     MS_LOG(ERROR) << "model file's name must be end with \".ms\".";
95     return RET_ERROR;
96   }
97   if (proj_dir_.size() == 0) {
98     proj_dir_ = "net";
99     MS_LOG(WARNING) << "parse model's name failed, use \"net\" instead.";
100   }
101   return RET_OK;
102 }
103 
Init(const CoderFlags & flags) const104 int Coder::Init(const CoderFlags &flags) const {
105   static const std::map<std::string, Target> kTargetMap = {
106     {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
107   static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
108   Configurator *config = Configurator::GetInstance();
109 
110   std::vector<std::function<bool()>> parsers;
111   parsers.emplace_back([&flags, config]() -> bool {
112     if (!FileExists(flags.model_path_)) {
113       MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid";
114       return false;
115     }
116     if (config->ParseProjDir(flags.model_path_) != RET_OK) {
117       return false;
118     }
119     return true;
120   });
121 
122   parsers.emplace_back([&flags, config]() -> bool {
123     auto target_item = kTargetMap.find(flags.target_);
124     MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
125     config->set_target(target_item->second);
126     return true;
127   });
128 
129   parsers.emplace_back([&flags, config]() -> bool {
130     auto code_item = kCodeModeMap.find(flags.code_mode_);
131     MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + flags.code_mode_);
132     config->set_code_mode(code_item->second);
133     return true;
134   });
135 
136   parsers.emplace_back([&flags, config]() -> bool {
137     if (flags.support_parallel_ && config->target() == kARM32M) {
138       MS_LOG(ERROR) << "arm32M cannot support parallel.";
139       return false;
140     }
141     config->set_support_parallel(flags.support_parallel_);
142     return true;
143   });
144 
145   parsers.emplace_back([&flags, config]() -> bool {
146     config->set_debug_mode(flags.debug_mode_);
147     return true;
148   });
149 
150   parsers.emplace_back([&flags, config]() -> bool {
151     const std::string slash = std::string(kSlash);
152     if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) {
153       MS_LOG(ERROR) << "code_gen code path " << flags.code_path_ << " is not valid";
154       return false;
155     }
156     config->set_code_path(flags.code_path_);
157     if (flags.code_path_.empty()) {
158       std::string path = ".." + slash + config->proj_dir();
159       config->set_code_path(path);
160     } else {
161       if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) {
162         std::string path = flags.code_path_ + slash + config->proj_dir();
163         config->set_code_path(path);
164       } else {
165         std::string path = flags.code_path_ + config->proj_dir();
166         config->set_code_path(path);
167       }
168     }
169     return InitProjDirs(flags.code_path_, config->proj_dir()) != RET_ERROR;
170   });
171 
172   if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) {
173     if (!flags.help) {
174       std::cerr << flags.Usage() << std::endl;
175       return 0;
176     }
177     return RET_ERROR;
178   }
179   auto print_parameter = [](auto name, auto value) {
180     MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value;
181   };
182 
183   print_parameter("modelPath", flags.model_path_);
184   print_parameter("projectName", config->proj_dir());
185   print_parameter("target", config->target());
186   print_parameter("codePath", config->code_path());
187   print_parameter("codeMode", config->code_mode());
188   print_parameter("debugMode", config->debug_mode());
189 
190   return RET_OK;
191 }
192 
RunCoder(int argc,const char ** argv)193 int RunCoder(int argc, const char **argv) {
194   CoderFlags flags;
195   Option<std::string> err = flags.ParseFlags(argc, argv, false, false);
196   if (err.IsSome()) {
197     std::cerr << err.Get() << std::endl;
198     std::cerr << flags.Usage() << std::endl;
199     return RET_ERROR;
200   }
201 
202   if (flags.help) {
203     std::cerr << flags.Usage() << std::endl;
204     return RET_OK;
205   }
206 
207   Coder code_gen;
208   STATUS status = code_gen.Init(flags);
209   if (status != RET_OK) {
210     MS_LOG(ERROR) << "Coder init Error";
211     return status;
212   }
213   status = code_gen.Run(flags.model_path_);
214   if (status != RET_OK) {
215     MS_LOG(ERROR) << "Coder Run Error.";
216     return status;
217   }
218   MS_LOG(INFO) << "end of Coder";
219   return RET_OK;
220 }
221 }  // namespace mindspore::lite::micro
222