1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/transfer_session.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "src/common/utils.h"
27 #include "src/common/file_utils.h"
28 #include "src/tensor.h"
29 #include "src/train/loss_kernel.h"
30 #include "src/train/optimizer_kernel.h"
31 #include "src/sub_graph_kernel.h"
32 #include "src/train/train_populate_parameter.h"
33 #include "src/executor.h"
34 #include "src/kernel_registry.h"
35 #include "src/runtime/kernel/arm/fp32_grad/convolution.h"
36 #include "nnacl/fp32/pack_fp32.h"
37 #include "src/train/train_export.h"
38 #include "src/train/train_utils.h"
39
40 namespace mindspore {
41 namespace lite {
TransferSession(const char * model_buf_backbone,size_t size_backbone,const lite::Context * context)42 TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context)
43 : is_valid_(false) {
44 lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
45 size_backbone_ = size_backbone;
46 if (lite_model_ != nullptr) {
47 std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
48 backbone_session_ =
49 reinterpret_cast<lite::LiteSession *>(session::LiteSession::CreateSession(lite_model_, size_backbone, context));
50 if (backbone_session_ != nullptr) {
51 is_valid_ = true;
52 } else {
53 MS_LOG(ERROR) << "transfer session: create backbone session failed";
54 }
55 }
56 }
57
GetInputs() const58 std::vector<tensor::MSTensor *> TransferSession::GetInputs() const { return combined_inputs_; }
59
CompileFormatTransform(tensor::MSTensor * out,tensor::MSTensor * in,int * mask,size_t mask_len)60 bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask, size_t mask_len) {
61 MS_ASSERT(out->shape().size() == mask_len);
62 for (std::size_t dim = 0; dim != out->shape().size(); ++dim) {
63 if (in->shape().at(mask[dim]) != out->shape().at(dim)) {
64 return false;
65 }
66 }
67 return true;
68 }
69
// Wires the head network onto the backbone: every head input is matched against
// the backbone outputs, first by exact shape, then (for 4-D tensors) by an
// NCHW->NHWC permutation. Matched pairs go into backbone_head_map_; unmatched
// head inputs become extra user-facing inputs in combined_inputs_.
// Returns RET_OK, or RET_ERROR if no head input could be connected at all.
int TransferSession::CompileTransferGraph() {
  combined_inputs_ = backbone_session_->GetInputs();
  auto outputs_backbone = backbone_session_->GetOutputs();
  auto inputs_head = lite::TrainSession::GetInputs();

  int ret = RET_OK;
  for (auto input : inputs_head) {
    bool match = false;
    mindspore::tensor::MSTensor *output = nullptr;
    for (auto it = outputs_backbone.begin(); it != outputs_backbone.end(); ++it) {
      output = it->second;
      // Cheap pre-filter: element count and rank must agree before comparing dims.
      if (output->ElementsNum() == input->ElementsNum() && output->shape().size() == input->shape().size()) {
        match = true;
        for (std::size_t dim = 0; dim != output->shape().size(); ++dim) {
          if (input->shape().at(dim) != output->shape().at(dim)) {
            match = false;
            break;
          }
        }
        // Exact shapes differ: for rank-4 tensors, try matching with the backbone
        // output treated as NCHW and the head input as NHWC (mask maps output axis
        // d to input axis mask[d]: N->N, C->C(3), H->H(1), W->W(2)).
        if (match == false && input->shape().size() == 4) {
          int nchw2nhwc_mask[4] = {0, 3, 1, 2};
          // NOTE(review): nchw2nhwc_ is a single session-wide flag; a later
          // non-permuted pair would overwrite it. RunGraph then applies the
          // transform to ALL connected pairs — confirm only one pairing is expected.
          nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
          match = nchw2nhwc_;
        }
        if (match) {
          break;
        }
      }
    }
    if (match) {
      // `output` still points at the matching backbone tensor from the inner loop.
      backbone_head_map_.push_back(std::make_pair(input, output));
    } else {
      // No backbone output feeds this head input; expose it as a session input.
      combined_inputs_.push_back(input);
    }
  }
  if (backbone_head_map_.size() == 0) {
    ret = RET_ERROR;
  }
  return ret;
}
110
GetInputsByTensorName(const std::string & tensor_name) const111 mindspore::tensor::MSTensor *TransferSession::GetInputsByTensorName(const std::string &tensor_name) const {
112 /* First look in backbone netwok */
113 auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
114 /* If not found look in head network */
115 if (ret == nullptr) {
116 ret = TrainSession::GetInputsByTensorName(tensor_name);
117 }
118 return ret;
119 }
120
~TransferSession()121 TransferSession::~TransferSession() {
122 if (backbone_session_ != nullptr) {
123 delete backbone_session_;
124 backbone_session_ = nullptr;
125 }
126 if (lite_model_ != nullptr) {
127 free(lite_model_);
128 lite_model_ = nullptr;
129 }
130 }
BindThread(bool if_bind)131 void TransferSession::BindThread(bool if_bind) {
132 backbone_session_->BindThread(if_bind);
133 TrainSession::BindThread(if_bind);
134 }
135
// Runs inference end-to-end: backbone first, then copies each connected
// backbone output into its head input (with an NCHW->NHWC repack when
// CompileTransferGraph detected a layout mismatch), then runs the head.
// Returns the first non-RET_OK status encountered.
int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
  auto ret = backbone_session_->RunGraph(before, after);
  if (ret != RET_OK) {
    return ret;
  }
  for (auto &backbone_head_pair : backbone_head_map_) {
    auto input = backbone_head_pair.first;    // head input (destination)
    auto output = backbone_head_pair.second;  // backbone output (source)
    char *input_data = reinterpret_cast<char *>(input->MutableData());
    char *output_data = reinterpret_cast<char *>(output->MutableData());
    if (nchw2nhwc_) {
      // Head input is NHWC: plane = H*W, channel = C taken from the input shape.
      int plane = input->shape().at(1) * input->shape().at(2);
      int batch = input->shape().at(0);
      int channel = input->shape().at(3);
      // PackNCHWToNHWCFp32(src, dst, ...): backbone output is the source here.
      PackNCHWToNHWCFp32(output_data, input_data, batch, plane, channel, 0, 1);
    } else {
      // Raw byte copy of output->Size() bytes into the input buffer.
      // NOTE(review): matching checked only element count and shape, not dtype,
      // so this presumes input->Size() >= output->Size() — confirm dtypes agree.
      std::copy(output_data, output_data + output->Size(), input_data);
    }
  }
  ret = lite::TrainSession::RunGraph(before, after);
  return ret;
}
158
ConnectionMap()159 std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
160 std::unordered_map<size_t, size_t> map;
161 for (auto &backbone_head_pair : backbone_head_map_) {
162 auto input = backbone_head_pair.first;
163 auto output = backbone_head_pair.second;
164 auto in_id = TSFindTensorByName(tensors_, input->tensor_name());
165 if (in_id == tensors_.size()) {
166 MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name();
167 map.clear();
168 return map;
169 }
170 auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name());
171 if (out_id == backbone_session_->tensors_.size()) {
172 MS_LOG(ERROR) << "cannot find input tensor " << output->tensor_name();
173 map.clear();
174 return map;
175 }
176 map[in_id] = out_id;
177 }
178 return map;
179 }
180
Export(const std::string & filename,ModelType model_type,QuantizationType quant_type,FormatType format,std::vector<std::string> out_put_tensor_name)181 int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
182 FormatType format, std::vector<std::string> out_put_tensor_name) {
183 if (format != FT_FLATBUFFERS) {
184 MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
185 return RET_ERROR;
186 }
187
188 if (model_type == MT_TRAIN) {
189 return TrainSession::Export(filename, model_type, quant_type, format);
190 }
191
192 bool orig_train_state = IsTrain();
193 Eval();
194 TrainExport texport(filename);
195 int status = texport.LoadModel(lite_model_, size_backbone_);
196 if (status != RET_OK) {
197 MS_LOG(ERROR) << "cannot init export";
198 return status;
199 }
200 auto connect_map = ConnectionMap();
201 texport.set_connect(connect_map);
202 if (nchw2nhwc_) {
203 status = texport.AddTransformNode();
204 if (status != RET_OK) {
205 MS_LOG(ERROR) << "cannot add transform node";
206 return status;
207 }
208 }
209 if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
210 std::vector<kernel::LiteKernel *> export_kernels = {};
211 status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
212 if (status != RET_OK) {
213 MS_LOG(ERROR) << "FindExportKernels failed.";
214 return RET_ERROR;
215 }
216 status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
217 } else {
218 status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
219 }
220 if (status != RET_OK) {
221 MS_LOG(ERROR) << "cannot serialize head";
222 return status;
223 }
224 status = texport.SaveToFile();
225 if (status != RET_OK) {
226 MS_LOG(ERROR) << "failed to save to " << filename;
227 return status;
228 }
229 if (orig_train_state) Train();
230 return status;
231 }
232 } // namespace lite
233
CreateTransferSessionInt(const char * model_buf_backbone,size_t size_backbone,const char * model_buf_head,size_t size_head,const lite::Context * context,bool train_mode,const lite::TrainCfg * cfg)234 static session::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
235 const char *model_buf_head, size_t size_head,
236 const lite::Context *context, bool train_mode,
237 const lite::TrainCfg *cfg) {
238 auto ValidModelSize = [](size_t size) -> bool {
239 constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B
240 return size < MaxModelSize && size > 0;
241 };
242 if (!ValidModelSize(size_backbone)) {
243 MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
244 return nullptr;
245 }
246 if (!ValidModelSize(size_head)) {
247 MS_LOG(ERROR) << "size_head too large: " << size_head;
248 return nullptr;
249 }
250 auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
251 if (session == nullptr) {
252 MS_LOG(ERROR) << "create transfer session failed";
253 return nullptr;
254 }
255 if (!session->is_valid()) {
256 MS_LOG(ERROR) << "create transfer session failed";
257 delete session;
258 return nullptr;
259 }
260
261 mindspore::lite::InnerContext *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
262 auto ret = session->Init(inner_context, cfg);
263 if (ret != lite::RET_OK) {
264 MS_LOG(ERROR) << "init transfer session failed";
265 delete session;
266 return nullptr;
267 }
268
269 auto model = std::shared_ptr<lite::Model>(lite::Model::Import(model_buf_head, size_head));
270 if (model == nullptr) {
271 MS_LOG(ERROR) << "create model for head train session failed";
272 delete session;
273 return nullptr;
274 }
275
276 ret = session->CompileTrainGraph(model);
277 if (ret != lite::RET_OK) {
278 MS_LOG(ERROR) << "Compiling Train Graph failed";
279 delete session;
280 return nullptr;
281 }
282 ret = session->CompileTransferGraph();
283 if (ret != lite::RET_OK) {
284 MS_LOG(ERROR) << "Compiling Transfer Graph failed";
285 delete session;
286 return nullptr;
287 }
288
289 if (train_mode) {
290 ret = session->Train();
291 } else {
292 ret = session->Eval();
293 }
294 if (ret != lite::RET_OK) {
295 MS_LOG(ERROR) << "Could not switch to Train Mode " << train_mode;
296 delete session;
297 return nullptr;
298 }
299 return session;
300 }
301
CreateTransferSession(const std::string & filename_backbone,const std::string & filename_head,const lite::Context * ctxt,bool train_mode,const lite::TrainCfg * cfg)302 session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
303 const std::string &filename_head,
304 const lite::Context *ctxt, bool train_mode,
305 const lite::TrainCfg *cfg) {
306 size_t size_head = 0;
307 size_t size_backbone = 0;
308 std::string filename = filename_head;
309 if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
310 filename = filename + ".ms";
311 }
312
313 auto buf_head = lite::ReadFile(filename.c_str(), &size_head);
314 if (buf_head == nullptr) {
315 return nullptr;
316 }
317 filename = filename_backbone;
318 if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
319 filename = filename + ".ms";
320 }
321
322 auto buf_backbone = lite::ReadFile(filename.c_str(), &size_backbone);
323 if (buf_backbone == nullptr) {
324 return nullptr;
325 }
326 return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg);
327 }
328 } // namespace mindspore
329