1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "coder/generator/train/train_generator.h"
18 #include <vector>
19 #include <string>
20 #include "coder/generator/component/common_component.h"
21 #include "coder/generator/component/weight_component.h"
22 #include "coder/generator/component/train_component.h"
23 #include "coder/generator/component/const_blocks/license.h"
24
25 namespace mindspore::lite::micro {
CodeGradientFunc(std::ofstream & ofs) const26 void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
27 ofs << "float ComputeLossAndGradient() {\n";
28 ofs << " float loss = 0;\n";
29 for (const auto &block : ctx_->train_blocks()) {
30 ofs << "\t{\n" << block << "\t}\n";
31 }
32 ofs << " return loss;\n";
33 ofs << "}\n";
34 }
35
CodeNetHFile()36 int TrainGenerator::CodeNetHFile() {
37 std::string net_include_file = net_src_file_path_ + net_inc_hfile_;
38 std::ofstream ofs(net_include_file);
39 MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
40 MS_LOG(INFO) << "write " << net_include_file;
41 ofs << g_hwLicense;
42 if (config_->code_mode() == CodeMode::Inference) {
43 ofs << "#include \"src/runtime/thread_pool.h\"\n";
44 }
45 ofs << "#include \"microtensor.h\"\n\n";
46 CodeTrainParams(ofs);
47 CodeInputState(ofs);
48 if (config_->target() != kARM32M) {
49 CodeInitWeightState(ofs);
50 }
51 CodeManageResourceState(ofs);
52 CodeInferenceState(ofs);
53 CodeFeaturesState(ofs);
54 CodeTrainState(ofs);
55 return RET_OK;
56 }
57
CodeNetCFile()58 int TrainGenerator::CodeNetCFile() {
59 std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
60 std::ofstream ofs(net_impl_file);
61 MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
62 MS_LOG(INFO) << "write " << net_impl_file;
63 CodeInputImplement(ofs, ctx_);
64 CodeInitResourceImplement(ofs, ctx_);
65 CodeFreeResourceImplement(ofs, ctx_);
66 CodeFeaturesImplement(ofs, ctx_);
67 CodeNetRunFunc(ofs);
68 CodeGradientFunc(ofs);
69 CodeTrainImplement(ofs, ctx_);
70 ofs.close();
71 return RET_OK;
72 }
73 } // namespace mindspore::lite::micro
74