• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "coder/generator/generator.h"
17 #include <sys/stat.h>
18 #include <set>
19 #include <fstream>
20 #include "coder/generator/component/component.h"
21 #include "coder/generator/component/cmake_component.h"
22 #include "coder/generator/component/weight_component.h"
23 #include "coder/generator/component/common_component.h"
24 #include "coder/generator/component/const_blocks/cmake_lists.h"
25 #include "coder/generator/component/const_blocks/debug_utils.h"
26 #include "coder/generator/component/const_blocks/load_input.h"
27 #include "coder/generator/component/const_blocks/calib_output.h"
28 #include "coder/generator/component/const_blocks/msession.h"
29 #include "coder/generator/component/const_blocks/mtensor.h"
30 #include "coder/generator/component/const_blocks/mmodel.h"
31 #include "coder/generator/component/const_blocks/thread_pool.h"
32 #include "coder/generator/component/const_blocks/benchmark.h"
33 #include "coder/generator/component/const_blocks/license.h"
34 #include "coder/log.h"
35 #include "coder/opcoders/parallel.h"
36 #include "coder/opcoders/kernel_registry.h"
37 
38 namespace mindspore::lite::micro {
WriteContentToFile(const std::string & file,const std::string & content)39 int WriteContentToFile(const std::string &file, const std::string &content) {
40   std::ofstream of(file);
41   if (of.bad()) {
42     MS_LOG(ERROR) << "open file error " << file;
43     return RET_ERROR;
44   }
45   MS_LOG(INFO) << "write " << file;
46   of << content;
47   of.close();
48   return RET_OK;
49 }
50 
Generator(std::unique_ptr<CoderContext> ctx)51 Generator::Generator(std::unique_ptr<CoderContext> ctx) {
52   ctx_ = std::move(ctx);
53   this->config_ = Configurator::GetInstance();
54   this->net_inc_hfile_ = "net.h";
55   this->net_src_cfile_ = "net.c";
56   this->net_weight_hfile_ = "weight.h";
57   this->net_src_file_path_ = config_->code_path() + kSourcePath;
58   this->net_main_file_path_ = config_->code_path() + kBenchmarkPath;
59   origin_umask_ = umask(user_umask_);
60   MS_LOG(DEBUG) << "origin umask: " << origin_umask_ << ", user umask: " << user_umask_;
61 }
62 
~Generator()63 Generator::~Generator() { (void)umask(origin_umask_); }
64 
CodeNetRunFunc(std::ofstream & ofs)65 void Generator::CodeNetRunFunc(std::ofstream &ofs) {
66   // generate net inference code
67   ofs << "void Inference() {\n";
68   if (config_->support_parallel()) {
69     ofs << gThreadNum << " = GetCurrentThreadNum();\n ";
70   }
71   for (const auto &block : ctx_->code_blocks()) {
72     ofs << "  {\n" << block << "  }\n";
73   }
74   ofs << "}\n";
75 }
76 
CodeSourceCMakeFile()77 int Generator::CodeSourceCMakeFile() {
78   std::string src_cmake_file = net_src_file_path_ + cmake_file_name_;
79   std::ofstream ofs(src_cmake_file);
80   MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
81   MS_LOG(INFO) << "write " << src_cmake_file;
82   CodeCMakeNetLibrary(ofs, ctx_, config_);
83   ofs.close();
84   return RET_OK;
85 }
86 
CodeStaticContent()87 int Generator::CodeStaticContent() {
88   std::vector<std::pair<std::string, std::string>> const_blocks = {
89     {config_->code_path() + "/" + "CMakeLists.txt", bench_cmake_lists_txt},
90     {net_main_file_path_ + "calib_output.h", calib_header},
91     {net_main_file_path_ + "calib_output.cc", calib_source},
92     {net_main_file_path_ + "load_input.h", load_input_h},
93     {net_main_file_path_ + "load_input.c", load_input_c},
94     {net_main_file_path_ + "benchmark.cc", benchmark_source},
95     {net_src_file_path_ + "CMakeLists.txt", src_cmake_lists_txt},
96     {net_src_file_path_ + "session.h", session_header},
97     {net_src_file_path_ + "tensor.h", tensor_header},
98     {net_src_file_path_ + "tensor.cc", tensor_source},
99     {net_src_file_path_ + "mmodel.h", model_header}};
100 
101   if (config_->support_parallel()) {
102     const_blocks.emplace_back(std::make_pair(net_src_file_path_ + kThreadWrapper, thread_header));
103   }
104   if (config_->debug_mode()) {
105     const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "debug_utils.h", debug_utils_h));
106     const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "debug_utils.c", debug_utils_c));
107   }
108   for (const auto &static_block : const_blocks) {
109     std::string file_name = static_block.first;
110     std::string content = static_block.second;
111     MS_CHECK_RET_CODE(WriteContentToFile(file_name, content), "write file failed");
112   }
113   return RET_OK;
114 }
115 
CodeSessionImplement()116 int Generator::CodeSessionImplement() {
117   std::string cfile = net_src_file_path_ + "session.cc";
118   std::ofstream ofs(cfile);
119   MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
120   MS_LOG(INFO) << "write " << cfile;
121   ofs << g_hwLicense;
122   ofs << "#include \"session.h\"\n";
123   ofs << "#include \"mmodel.h\"\n";
124   ofs << "#include \"net.h\"\n";
125   ofs << "#include <new>\n\n";
126   if (config_->support_parallel()) {
127     ofs << "#include \"" << kThreadWrapper << "\"\n";
128   }
129   CodeSessionCompileGraph(ofs, ctx_, config_);
130   CodeCreateSessionDestructor(ofs, config_);
131   ofs << session_source;
132   CodeCreateSessionImplement(ofs, config_);
133   return RET_OK;
134 }
135 
CodeWeightFile()136 int Generator::CodeWeightFile() {
137   // weight header file
138   std::string hfile = net_src_file_path_ + net_weight_hfile_;
139   std::ofstream hofs(hfile);
140   MS_CHECK_TRUE(!hofs.bad(), "filed to open file");
141   MS_LOG(INFO) << "write " << hfile;
142   CodeWeightFileHeader(hofs, ctx_);
143 
144   // weight source file
145   std::string cfile = net_src_file_path_ + "weight.c";
146   std::ofstream cofs(cfile);
147   MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
148   MS_LOG(INFO) << "write " << cfile;
149   cofs << g_hwLicense;
150   cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
151   cofs << "int  " << gThreadNum << " = 1; \n";
152   cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n";
153 
154   if (config_->target() != kARM32M) {
155     std::string net_file = net_src_file_path_ + "net.bin";
156     SaveDataToNet(ctx_->saved_weights(), net_file);
157     CodeModelParamsForNet(hofs, cofs, ctx_);
158     CodeWeightInitFunc(cofs, ctx_);
159   } else {
160     CodeModelParamsState(hofs, ctx_->saved_weights());
161     CodeModelParamsData(cofs, ctx_->saved_weights());
162   }
163   hofs.close();
164   cofs.close();
165   return RET_OK;
166 }
167 
CodeRegKernelHFile()168 int Generator::CodeRegKernelHFile() {
169   if (!KernelRegistry::GetInstance()->HasKernelRegistered()) return RET_OK;
170   if (!KernelRegistry::GetInstance()->CheckRegistered(schema::PrimitiveType_Custom)) {
171     MS_LOG(ERROR) << "Only support custom kernel to register now!";
172     return RET_ERROR;
173   }
174 
175   std::string reg_kernel_header = net_src_file_path_ + "registered_kernel.h";
176   std::ofstream cofs(reg_kernel_header);
177   MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
178   MS_LOG(INFO) << "write " << reg_kernel_header;
179   cofs << g_hwLicense;
180   cofs << "#include \"nnacl/tensor_c.h\"\n";
181   cofs << "#include \"nnacl/custom_parameter.h\"\n\n";
182   cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n";
183   return RET_OK;
184 }
185 
GenerateCode()186 int Generator::GenerateCode() {
187   MS_CHECK_RET_CODE(CodeNetHFile(), "code net h file failed.");
188   MS_CHECK_RET_CODE(CodeNetCFile(), "code net c file failed.");
189   MS_CHECK_RET_CODE(CodeWeightFile(), "code weight file failed.");
190   MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed.");
191   MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed.");
192   MS_CHECK_RET_CODE(CodeSessionImplement(), "code session file failed.");
193   MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed.");
194   return RET_OK;
195 }
196 }  // namespace mindspore::lite::micro
197