1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "coder/generator/generator.h"
17 #include <sys/stat.h>
18 #include <set>
19 #include <fstream>
20 #include "coder/generator/component/component.h"
21 #include "coder/generator/component/cmake_component.h"
22 #include "coder/generator/component/weight_component.h"
23 #include "coder/generator/component/common_component.h"
24 #include "coder/generator/component/const_blocks/cmake_lists.h"
25 #include "coder/generator/component/const_blocks/debug_utils.h"
26 #include "coder/generator/component/const_blocks/load_input.h"
27 #include "coder/generator/component/const_blocks/calib_output.h"
28 #include "coder/generator/component/const_blocks/msession.h"
29 #include "coder/generator/component/const_blocks/mtensor.h"
30 #include "coder/generator/component/const_blocks/mmodel.h"
31 #include "coder/generator/component/const_blocks/thread_pool.h"
32 #include "coder/generator/component/const_blocks/benchmark.h"
33 #include "coder/generator/component/const_blocks/license.h"
34 #include "coder/log.h"
35 #include "coder/opcoders/parallel.h"
36 #include "coder/opcoders/kernel_registry.h"
37
38 namespace mindspore::lite::micro {
WriteContentToFile(const std::string & file,const std::string & content)39 int WriteContentToFile(const std::string &file, const std::string &content) {
40 std::ofstream of(file);
41 if (of.bad()) {
42 MS_LOG(ERROR) << "open file error " << file;
43 return RET_ERROR;
44 }
45 MS_LOG(INFO) << "write " << file;
46 of << content;
47 of.close();
48 return RET_OK;
49 }
50
Generator(std::unique_ptr<CoderContext> ctx)51 Generator::Generator(std::unique_ptr<CoderContext> ctx) {
52 ctx_ = std::move(ctx);
53 this->config_ = Configurator::GetInstance();
54 this->net_inc_hfile_ = "net.h";
55 this->net_src_cfile_ = "net.c";
56 this->net_weight_hfile_ = "weight.h";
57 this->net_src_file_path_ = config_->code_path() + kSourcePath;
58 this->net_main_file_path_ = config_->code_path() + kBenchmarkPath;
59 origin_umask_ = umask(user_umask_);
60 MS_LOG(DEBUG) << "origin umask: " << origin_umask_ << ", user umask: " << user_umask_;
61 }
62
~Generator()63 Generator::~Generator() { (void)umask(origin_umask_); }
64
CodeNetRunFunc(std::ofstream & ofs)65 void Generator::CodeNetRunFunc(std::ofstream &ofs) {
66 // generate net inference code
67 ofs << "void Inference() {\n";
68 if (config_->support_parallel()) {
69 ofs << gThreadNum << " = GetCurrentThreadNum();\n ";
70 }
71 for (const auto &block : ctx_->code_blocks()) {
72 ofs << " {\n" << block << " }\n";
73 }
74 ofs << "}\n";
75 }
76
CodeSourceCMakeFile()77 int Generator::CodeSourceCMakeFile() {
78 std::string src_cmake_file = net_src_file_path_ + cmake_file_name_;
79 std::ofstream ofs(src_cmake_file);
80 MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
81 MS_LOG(INFO) << "write " << src_cmake_file;
82 CodeCMakeNetLibrary(ofs, ctx_, config_);
83 ofs.close();
84 return RET_OK;
85 }
86
CodeStaticContent()87 int Generator::CodeStaticContent() {
88 std::vector<std::pair<std::string, std::string>> const_blocks = {
89 {config_->code_path() + "/" + "CMakeLists.txt", bench_cmake_lists_txt},
90 {net_main_file_path_ + "calib_output.h", calib_header},
91 {net_main_file_path_ + "calib_output.cc", calib_source},
92 {net_main_file_path_ + "load_input.h", load_input_h},
93 {net_main_file_path_ + "load_input.c", load_input_c},
94 {net_main_file_path_ + "benchmark.cc", benchmark_source},
95 {net_src_file_path_ + "CMakeLists.txt", src_cmake_lists_txt},
96 {net_src_file_path_ + "session.h", session_header},
97 {net_src_file_path_ + "tensor.h", tensor_header},
98 {net_src_file_path_ + "tensor.cc", tensor_source},
99 {net_src_file_path_ + "mmodel.h", model_header}};
100
101 if (config_->support_parallel()) {
102 const_blocks.emplace_back(std::make_pair(net_src_file_path_ + kThreadWrapper, thread_header));
103 }
104 if (config_->debug_mode()) {
105 const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "debug_utils.h", debug_utils_h));
106 const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "debug_utils.c", debug_utils_c));
107 }
108 for (const auto &static_block : const_blocks) {
109 std::string file_name = static_block.first;
110 std::string content = static_block.second;
111 MS_CHECK_RET_CODE(WriteContentToFile(file_name, content), "write file failed");
112 }
113 return RET_OK;
114 }
115
CodeSessionImplement()116 int Generator::CodeSessionImplement() {
117 std::string cfile = net_src_file_path_ + "session.cc";
118 std::ofstream ofs(cfile);
119 MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
120 MS_LOG(INFO) << "write " << cfile;
121 ofs << g_hwLicense;
122 ofs << "#include \"session.h\"\n";
123 ofs << "#include \"mmodel.h\"\n";
124 ofs << "#include \"net.h\"\n";
125 ofs << "#include <new>\n\n";
126 if (config_->support_parallel()) {
127 ofs << "#include \"" << kThreadWrapper << "\"\n";
128 }
129 CodeSessionCompileGraph(ofs, ctx_, config_);
130 CodeCreateSessionDestructor(ofs, config_);
131 ofs << session_source;
132 CodeCreateSessionImplement(ofs, config_);
133 return RET_OK;
134 }
135
CodeWeightFile()136 int Generator::CodeWeightFile() {
137 // weight header file
138 std::string hfile = net_src_file_path_ + net_weight_hfile_;
139 std::ofstream hofs(hfile);
140 MS_CHECK_TRUE(!hofs.bad(), "filed to open file");
141 MS_LOG(INFO) << "write " << hfile;
142 CodeWeightFileHeader(hofs, ctx_);
143
144 // weight source file
145 std::string cfile = net_src_file_path_ + "weight.c";
146 std::ofstream cofs(cfile);
147 MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
148 MS_LOG(INFO) << "write " << cfile;
149 cofs << g_hwLicense;
150 cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
151 cofs << "int " << gThreadNum << " = 1; \n";
152 cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n";
153
154 if (config_->target() != kARM32M) {
155 std::string net_file = net_src_file_path_ + "net.bin";
156 SaveDataToNet(ctx_->saved_weights(), net_file);
157 CodeModelParamsForNet(hofs, cofs, ctx_);
158 CodeWeightInitFunc(cofs, ctx_);
159 } else {
160 CodeModelParamsState(hofs, ctx_->saved_weights());
161 CodeModelParamsData(cofs, ctx_->saved_weights());
162 }
163 hofs.close();
164 cofs.close();
165 return RET_OK;
166 }
167
CodeRegKernelHFile()168 int Generator::CodeRegKernelHFile() {
169 if (!KernelRegistry::GetInstance()->HasKernelRegistered()) return RET_OK;
170 if (!KernelRegistry::GetInstance()->CheckRegistered(schema::PrimitiveType_Custom)) {
171 MS_LOG(ERROR) << "Only support custom kernel to register now!";
172 return RET_ERROR;
173 }
174
175 std::string reg_kernel_header = net_src_file_path_ + "registered_kernel.h";
176 std::ofstream cofs(reg_kernel_header);
177 MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
178 MS_LOG(INFO) << "write " << reg_kernel_header;
179 cofs << g_hwLicense;
180 cofs << "#include \"nnacl/tensor_c.h\"\n";
181 cofs << "#include \"nnacl/custom_parameter.h\"\n\n";
182 cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n";
183 return RET_OK;
184 }
185
/**
 * Top-level driver that produces every output file, in this fixed order:
 * 1. net header and net source;
 * 2. weight files;
 * 3. the source-directory CMake file;
 * 4. static support files;
 * 5. the session implementation;
 * 6. the registered-kernel header (which is skipped when no kernel is registered).
 *
 * @return RET_OK when every step succeeds. Each step is wrapped in
 *         MS_CHECK_RET_CODE, which presumably logs the given message and
 *         returns early on a non-RET_OK code (macro defined elsewhere — confirm).
 */
int Generator::GenerateCode() {
  MS_CHECK_RET_CODE(CodeNetHFile(), "code net h file failed.");
  MS_CHECK_RET_CODE(CodeNetCFile(), "code net c file failed.");
  MS_CHECK_RET_CODE(CodeWeightFile(), "code weight file failed.");
  MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed.");
  MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed.");
  MS_CHECK_RET_CODE(CodeSessionImplement(), "code session file failed.");
  MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed.");
  return RET_OK;
}
196 } // namespace mindspore::lite::micro
197