1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/transfer_session.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "src/common/utils.h"
27 #include "src/common/file_utils.h"
28 #include "src/tensor.h"
29 #include "src/train/loss_kernel.h"
30 #include "src/train/optimizer_kernel.h"
31 #include "src/train/train_populate_parameter.h"
32 #include "src/litert/executor.h"
33 #include "src/train/train_export.h"
34 #include "src/train/train_utils.h"
35
36 namespace mindspore {
37 namespace lite {
TransferSession(const char * model_buf_backbone,size_t size_backbone,const std::shared_ptr<lite::InnerContext> & context)38 TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone,
39 const std::shared_ptr<lite::InnerContext> &context)
40 : is_valid_(false) {
41 lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
42 size_backbone_ = size_backbone;
43 if (lite_model_ != nullptr) {
44 std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
45 backbone_session_ = LiteSession::CreateSession(lite_model_, size_backbone, context);
46 if (backbone_session_ != nullptr) {
47 is_valid_ = true;
48 } else {
49 MS_LOG(ERROR) << "transfer session: create backbone session failed";
50 }
51 }
52 }
53
GetInputs() const54 std::vector<lite::Tensor *> TransferSession::GetInputs() const { return combined_inputs_; }
55
CompileFormatTransform(lite::Tensor * out,lite::Tensor * in,int * mask,size_t mask_len)56 bool TransferSession::CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len) {
57 MS_ASSERT(out->shape().size() == mask_len);
58 for (std::size_t dim = 0; dim != out->shape().size(); ++dim) {
59 if (in->shape().at(mask[dim]) != out->shape().at(dim)) {
60 return false;
61 }
62 }
63 return true;
64 }
65
CompileTransferGraph()66 int TransferSession::CompileTransferGraph() {
67 combined_inputs_ = backbone_session_->GetInputs();
68 auto outputs_backbone = backbone_session_->GetOutputs();
69 auto inputs_head = lite::TrainSession::GetInputs();
70
71 int ret = RET_OK;
72 for (auto input : inputs_head) {
73 bool match = false;
74 mindspore::lite::Tensor *output = nullptr;
75 for (auto it = outputs_backbone.begin(); it != outputs_backbone.end(); ++it) {
76 output = it->second;
77 if (output->ElementsNum() == input->ElementsNum() && output->shape().size() == input->shape().size()) {
78 match = true;
79 for (std::size_t dim = 0; dim != output->shape().size(); ++dim) {
80 if (input->shape().at(dim) != output->shape().at(dim)) {
81 match = false;
82 break;
83 }
84 }
85 if (match == false && input->shape().size() == 4) {
86 int nchw2nhwc_mask[4] = {0, 3, 1, 2};
87 nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
88 match = nchw2nhwc_;
89 }
90 if (match) {
91 break;
92 }
93 }
94 }
95 if (match) {
96 backbone_head_map_.push_back(std::make_pair(input, output));
97 } else {
98 combined_inputs_.push_back(input);
99 }
100 }
101 if (backbone_head_map_.size() == 0) {
102 ret = RET_ERROR;
103 }
104 return ret;
105 }
106
GetInputsByTensorName(const std::string & tensor_name) const107 mindspore::lite::Tensor *TransferSession::GetInputsByTensorName(const std::string &tensor_name) const {
108 /* First look in backbone netwok */
109 auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
110 /* If not found look in head network */
111 if (ret == nullptr) {
112 ret = TrainSession::GetInputsByTensorName(tensor_name);
113 }
114 return ret;
115 }
116
~TransferSession()117 TransferSession::~TransferSession() {
118 if (backbone_session_ != nullptr) {
119 delete backbone_session_;
120 backbone_session_ = nullptr;
121 }
122 if (lite_model_ != nullptr) {
123 free(lite_model_);
124 lite_model_ = nullptr;
125 }
126 }
BindThread(bool if_bind)127 void TransferSession::BindThread(bool if_bind) {
128 backbone_session_->BindThread(if_bind);
129 TrainSession::BindThread(if_bind);
130 }
131
/**
 * Run backbone then head: executes the backbone graph, copies each matched
 * backbone output into its paired head input (transposing layout when
 * nchw2nhwc_ was set during CompileTransferGraph), then runs the head graph.
 * Returns the first non-RET_OK status encountered.
 */
int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
  auto ret = backbone_session_->RunGraph(before, after);
  if (ret != RET_OK) {
    return ret;
  }
  // Feed each matched backbone output into the corresponding head input.
  for (auto &backbone_head_pair : backbone_head_map_) {
    auto input = backbone_head_pair.first;    // head-network input tensor
    auto output = backbone_head_pair.second;  // backbone-network output tensor
    float *input_data = reinterpret_cast<float *>(input->MutableData());
    float *output_data = reinterpret_cast<float *>(output->MutableData());
    if (nchw2nhwc_) {
      // Per-batch layout transpose between the two networks. NOTE(review):
      // batch/plane/channel are taken from the head input's 4-D shape; this
      // assumes the backbone output holds the same element count in the
      // permuted layout, as established by CompileFormatTransform — confirm
      // if the mask there ever changes.
      int batch = input->shape().at(0);
      int plane = input->shape().at(1) * input->shape().at(2);
      int channel = input->shape().at(3);
      int img_size = plane * channel;
      for (int b = 0; b < batch; b++) {
        float *in = input_data + b * img_size;
        float *out = output_data + b * img_size;
        for (int p = 0; p < plane; p++) {
          for (int c = 0; c < channel; c++) {
            in[p * channel + c] = out[c * plane + p];
          }
        }
      }
    } else {
      // Identical layout: straight element-wise copy.
      std::copy(output_data, output_data + output->ElementsNum(), input_data);
    }
  }
  ret = lite::TrainSession::RunGraph(before, after);
  return ret;
}
163
ConnectionMap()164 std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
165 std::unordered_map<size_t, size_t> map;
166 for (auto &backbone_head_pair : backbone_head_map_) {
167 auto input = backbone_head_pair.first;
168 auto output = backbone_head_pair.second;
169 auto in_id = TSFindTensorByName(tensors_, input->tensor_name());
170 if (in_id == tensors_.size()) {
171 MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name();
172 map.clear();
173 return map;
174 }
175 auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name());
176 if (out_id == backbone_session_->tensors_.size()) {
177 MS_LOG(ERROR) << "cannot find input tensor " << output->tensor_name();
178 map.clear();
179 return map;
180 }
181 map[in_id] = out_id;
182 }
183 return map;
184 }
185
/**
 * Shared implementation behind both Export overloads.
 * @tparam DestType either `const std::string &` (file path) or `Buffer *`
 *         (in-memory buffer); any other type is rejected at compile time.
 * @param model_type MT_TRAIN exports the head training graph via
 *        TrainSession::Export; MT_INFERENCE fuses backbone + head.
 * @param quant_type quantization applied during export.
 * @param format only FT_FLATBUFFERS is supported.
 * @param out_put_tensor_name optional output-tensor cut points for inference
 *        export; empty means export up to the session's output tensors.
 * @return RET_OK on success, otherwise an error status.
 */
template <typename DestType>
int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                                 FormatType format, std::vector<std::string> out_put_tensor_name) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
  } else {
    MS_LOG(ERROR) << "Unsupported destination.";
    return RET_ERROR;
  }
  if (format != FT_FLATBUFFERS) {
    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
    return RET_ERROR;
  }

  // Train export covers only the head network; delegate entirely.
  if (model_type == MT_TRAIN) {
    return TrainSession::Export(destination, model_type, quant_type, format);
  }

  // Export requires eval mode; remember the mode so it can be restored.
  // NOTE(review): the early error returns below leave the session in eval
  // mode even if it was in train mode on entry — confirm whether callers
  // depend on the mode being restored on failure.
  bool orig_train_state = IsTrain();
  if (Eval() != RET_OK) {
    MS_LOG(ERROR) << "eval failed.";
    return RET_ERROR;
  }
  TrainExport texport(destination);
  int status = texport.LoadModel(lite_model_, size_backbone_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot init export";
    return status;
  }
  // Tell the exporter how head inputs map onto backbone outputs.
  auto connect_map = ConnectionMap();
  texport.set_connect(connect_map);
  if (nchw2nhwc_) {
    // A layout transpose happens at runtime between the graphs; the exported
    // fused model needs an explicit transform node in its place.
    status = texport.AddTransformNode();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "cannot add transform node";
      return status;
    }
  }
  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    // Caller narrowed the export to specific output tensors: export only the
    // kernels needed to produce them.
    std::vector<kernel::KernelExec *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindExportKernels failed.";
      return RET_ERROR;
    }
    status = texport.ExportNet(export_kernels, tensors_, {}, out_put_tensor_name, model_.get(), quant_type,
                               backbone_session_->model_);
  } else {
    status = texport.ExportNet(inference_kernels_, tensors_, {}, GetOutputTensorNames(), model_.get(), quant_type,
                               backbone_session_->model_);
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot serialize head";
    return status;
  }

  if constexpr (std::is_same_v<DestType, const std::string &>) {
    status = texport.SaveToFile();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "failed to save to " << destination;
      return status;
    }
  } else {
    status = texport.SaveToBuffer();
    MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
  }

  // Restore train mode if that is what the caller had before Export.
  if (orig_train_state) {
    auto ret = Train();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "train failed.";
      return RET_ERROR;
    }
  }
  return status;
}
264
Export(const std::string & filename,ModelType model_type,QuantizationType quant_type,FormatType format,std::vector<std::string> out_put_tensor_name)265 int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
266 FormatType format, std::vector<std::string> out_put_tensor_name) {
267 return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
268 }
269
Export(Buffer * model_buffer,ModelType model_type,QuantizationType quant_type,FormatType format,std::vector<std::string> out_put_tensor_name)270 int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
271 std::vector<std::string> out_put_tensor_name) {
272 return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
273 }
274
/**
 * Factory for a fully initialized transfer-learning session: loads the
 * backbone, initializes training, imports and compiles the head model, wires
 * the two graphs, and puts the session in train or eval mode.
 * Returns nullptr (after logging) on any failure; ownership of the returned
 * session passes to the caller.
 */
lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
                                            const char *model_buf_head, size_t size_head,
                                            const std::shared_ptr<InnerContext> &context, bool train_mode,
                                            const lite::TrainCfg *cfg) {
  auto ValidModelSize = [](size_t size) -> bool {
    constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL;  // 1 GB
    return size < MaxModelSize && size > 0;
  };
  if (!ValidModelSize(size_backbone)) {
    // Message fixed: the check also rejects size 0, not just oversized models.
    MS_LOG(ERROR) << "invalid size_backbone: " << size_backbone;
    return nullptr;
  }
  if (!ValidModelSize(size_head)) {
    MS_LOG(ERROR) << "invalid size_head: " << size_head;
    return nullptr;
  }
  // unique_ptr guarantees cleanup on every error path below (previously each
  // path repeated a manual `delete session`).
  auto session = std::unique_ptr<lite::TransferSession>(
    new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context));
  if (session == nullptr) {
    MS_LOG(ERROR) << "create transfer session failed";
    return nullptr;
  }
  if (!session->is_valid()) {
    MS_LOG(ERROR) << "create transfer session failed";
    return nullptr;
  }
  auto ret = session->TrainInit(context, cfg);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "init transfer session failed";
    return nullptr;
  }

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(model_buf_head, size_head));
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for head train session failed";
    return nullptr;
  }

  ret = session->CompileTrainGraph(model);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Compiling Train Graph failed";
    return nullptr;
  }
  ret = session->CompileTransferGraph();
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Compiling Transfer Graph failed";
    return nullptr;
  }

  if (train_mode) {
    ret = session->Train();
  } else {
    ret = session->Eval();
  }
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Could not switch to Train Mode " << train_mode;
    return nullptr;
  }
  // Success: hand ownership to the caller.
  return session.release();
}
340 } // namespace lite
341 } // namespace mindspore
342