/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "coder/generator/train/train_generator.h"
18 #include <vector>
19 #include <string>
20 #include "coder/generator/component/common_component.h"
21 #include "coder/generator/component/weight_component.h"
22 #include "coder/generator/component/train_component.h"
23 #include "coder/generator/component/const_blocks/license.h"
24 
25 namespace mindspore::lite::micro {
CodeGradientFunc(std::ofstream & ofs) const26 void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
27   ofs << "float ComputeLossAndGradient() {\n";
28   ofs << "  float loss = 0;\n";
29   for (const auto &block : ctx_->train_blocks()) {
30     ofs << "\t{\n" << block << "\t}\n";
31   }
32   ofs << "  return loss;\n";
33   ofs << "}\n";
34 }
35 
CodeNetHFile()36 int TrainGenerator::CodeNetHFile() {
37   std::string net_include_file = net_src_file_path_ + net_inc_hfile_;
38   std::ofstream ofs(net_include_file);
39   MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
40   MS_LOG(INFO) << "write " << net_include_file;
41   ofs << g_hwLicense;
42   if (config_->code_mode() == CodeMode::Inference) {
43     ofs << "#include \"src/runtime/thread_pool.h\"\n";
44   }
45   ofs << "#include \"microtensor.h\"\n\n";
46   CodeTrainParams(ofs);
47   CodeInputState(ofs);
48   if (config_->target() != kARM32M) {
49     CodeInitWeightState(ofs);
50   }
51   CodeManageResourceState(ofs);
52   CodeInferenceState(ofs);
53   CodeFeaturesState(ofs);
54   CodeTrainState(ofs);
55   return RET_OK;
56 }
57 
CodeNetCFile()58 int TrainGenerator::CodeNetCFile() {
59   std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
60   std::ofstream ofs(net_impl_file);
61   MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
62   MS_LOG(INFO) << "write " << net_impl_file;
63   CodeInputImplement(ofs, ctx_);
64   CodeInitResourceImplement(ofs, ctx_);
65   CodeFreeResourceImplement(ofs, ctx_);
66   CodeFeaturesImplement(ofs, ctx_);
67   CodeNetRunFunc(ofs);
68   CodeGradientFunc(ofs);
69   CodeTrainImplement(ofs, ctx_);
70   ofs.close();
71   return RET_OK;
72 }
73 }  // namespace mindspore::lite::micro
74