1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "coder/generator/component/train_component.h"
18 #include <string>
19 #include "coder/utils/type_cast.h"
20
21 namespace mindspore::lite::micro {
22
/**
 * Emit the C type declarations used by the generated training API:
 * `struct TrainParameter` (beta1_/beta2_/epsilon_), `enum EarlyStopType` and `struct EarlyStop`.
 * Each declaration is followed by one blank line.
 *
 * @param ofs the stream receiving the generated C source
 */
void CodeTrainParams(std::ofstream &ofs) {
  const char *const train_parameter_decl =
    "struct TrainParameter {\n"
    " float beta1_;\n"
    " float beta2_;\n"
    " float epsilon_;\n"
    "};\n";
  // NOTE(review): "WeigthDiff" is misspelled, but it is part of the generated public enum, so it is kept as-is.
  const char *const early_stop_type_decl =
    "enum EarlyStopType {\n"
    " Diff = 0,\n"
    " WeigthDiff = 1,\n"
    " Abs = 2,\n"
    "};\n";
  const char *const early_stop_decl =
    "struct EarlyStop {\n"
    " enum EarlyStopType type;\n"
    " float tolerate;\n"
    "};\n";
  ofs << train_parameter_decl << "\n" << early_stop_type_decl << "\n" << early_stop_decl << "\n";
}
41
/**
 * Emit the documented C prototypes of `GetFeatures` and `UpdateFeatures`.
 * Each prototype is preceded by its doc comment and followed by one blank line.
 *
 * @param ofs the stream receiving the generated C header text
 */
void CodeFeaturesState(std::ofstream &ofs) {
  // Writes "/**\n *\n", the given doc lines, " */\n", then the prototype and a blank line.
  auto emit_prototype = [&ofs](const char *doc_lines, const char *prototype) {
    ofs << "/**\n *\n" << doc_lines << " */\n" << prototype << "\n\n";
  };
  emit_prototype(" * @param size, return the number of features\n"
                 " * @return, the address of features\n",
                 "FeatureParam *GetFeatures(int *size);");
  emit_prototype(" * @param features, the address of features\n"
                 " * @param size, the number of features\n"
                 " * @return, status\n",
                 "int UpdateFeatures(FeatureParam *features, int size);");
}
57
CodeFeaturesImplement(std::ofstream & ofs,const std::unique_ptr<CoderContext> & ctx)58 void CodeFeaturesImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
59 size_t features_num = 0;
60 ofs << "static FeatureParam feature_params[] = {\n";
61 for (const auto &item : ctx->saved_weights()) {
62 std::string addr = item.first;
63 Tensor *tensor = item.second;
64 if (tensor->tensor_name().empty()) {
65 MS_LOG(ERROR) << "exist empty feature";
66 continue;
67 }
68 ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
69 << EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
70 features_num++;
71 }
72 ofs << "};\n";
73
74 ofs << "FeatureParam *GetFeatures(int *size) {\n"
75 << " *size = " << features_num << ";\n"
76 << " return feature_params;\n"
77 "}\n\n";
78
79 ofs << "int "
80 << "UpdateFeatures(FeatureParam *features, int size) {\n"
81 << " for (int i = 0; i < size; ++i) {\n"
82 " FeatureParam *src = features + i;\n"
83 " FeatureParam dst;\n"
84 " // find the dst feature\n"
85 " bool is_find = false;\n"
86 << " for (int j = 0; j < " << features_num << "; ++j) {\n"
87 << " if (strcmp(src->name, feature_params[j].name) == 0) {\n"
88 " dst = feature_params[j];\n"
89 " is_find = true;\n"
90 " break;\n"
91 " }\n"
92 " }\n"
93 " if (!is_find) {\n"
94 " MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
95 " return RET_ERROR;\n"
96 " }\n"
97 " if (src->elenums != dst.elenums) {\n"
98 " MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
99 "dst.elenums);\n"
100 " return RET_ERROR;\n"
101 " }\n"
102 " memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
103 " }\n"
104 " MICRO_INFO(\"update features map success\");\n"
105 " return RET_OK;\n"
106 "}\n\n";
107 }
108
/**
 * Emit the documented C prototype of `Train(...)`.
 * The generated doc comment now covers every parameter, including `early_stop`.
 *
 * @param ofs the stream receiving the generated C header text
 */
void CodeTrainState(std::ofstream &ofs) {
  ofs << "/**\n"
         " * Train Function\n"
         " * @param epoch, the train epoch\n"
         " * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
         " * @param use_train_param, default parameters already exist, such as the momentum; user can update these\n"
         " * parameters to improve the accuracy\n"
         " * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
         " * @param early_stop, the early stop strategy (type and tolerance), see struct EarlyStop\n"
         " * @return status\n"
         " */\n"
         "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
         "const struct EarlyStop *early_stop);\n\n";
}
123
CodeTrainImplement(std::ofstream & ofs,const std::unique_ptr<CoderContext> & ctx)124 void CodeTrainImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
125 std::vector<Tensor *> inputs = ctx->graph_inputs();
126 size_t inputs_num = inputs.size();
127 auto inputs_tostring = [&inputs, &ctx]() {
128 std::string result;
129 result += "{";
130 for (size_t i = 0; i < inputs.size(); ++i) {
131 result += ctx->input_name() + std::to_string(i) + ", ";
132 }
133 result += "}";
134 return result;
135 };
136 auto wrap = [](size_t i) { return "[" + std::to_string(i) + "]"; };
137 auto offset_inputs = [&inputs, &wrap]() {
138 std::string src = "origin_inputs";
139 std::string dst = "input_ptr";
140 std::string result;
141 for (size_t i = 0; i < inputs.size(); ++i) {
142 result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
143 }
144 return result;
145 };
146 auto varify_inputs = [&inputs, &wrap]() {
147 std::string result;
148 for (size_t i = 0; i < inputs.size(); ++i) {
149 result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
150 i < inputs.size() - 1 ? result += " || " : result += "";
151 }
152 return result;
153 };
154 ofs << "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter "
155 "*parameter, const struct EarlyStop *early_stop) {\n"
156 " if (iterations <= 0 || epoch <= 0) {\n"
157 " MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n"
158 " return RET_ERROR;\n"
159 " }\n"
160 " MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n"
161 << " const void *origin_input[] = " << inputs_tostring() << ";\n";
162 ofs << " if (" << varify_inputs() << ") {\n"
163 << " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n"
164 " return RET_ERROR;\n"
165 " }\n";
166 ofs << " for (int i = 0; i < epoch; ++i) {\n"
167 << " const void *input_ptr[" << inputs_num << "];\n"
168 << " float loss = 0;\n"
169 << " for (int j = 0; j < iterations; ++j) {\n"
170 << " " << offset_inputs() << "\n"
171 << " "
172 << "_SetInputs(input_ptr, " << inputs_num << ");\n"
173 << " "
174 << "_Inference();\n"
175 << " loss = "
176 << "ComputeLossAndGradient();\n"
177 << " }\n"
178 " }\n"
179 " return RET_OK;\n"
180 "};\n\n";
181 }
182 } // namespace mindspore::lite::micro
183