/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/transfer_session.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "src/common/utils.h"
#include "src/common/file_utils.h"
#include "src/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/sub_graph_kernel.h"
#include "src/train/train_populate_parameter.h"
#include "src/executor.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "nnacl/fp32/pack_fp32.h"
#include "src/train/train_export.h"
#include "src/train/train_utils.h"

namespace mindspore {
namespace lite {
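// Copies the backbone model buffer and creates the backbone inference session.
// is_valid_ becomes true only if the backbone session was created successfully.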
TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context)
    : is_valid_(false) {
  lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
  size_backbone_ = size_backbone;
  if (lite_model_ != nullptr) {
    std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
    backbone_session_ =
      reinterpret_cast<lite::LiteSession *>(session::LiteSession::CreateSession(lite_model_, size_backbone, context));
    if (backbone_session_ != nullptr) {
      is_valid_ = true;
    } else {
      MS_LOG(ERROR) << "transfer session: create backbone session failed";
    }
  }
}

std::vector<tensor::MSTensor *> TransferSession::GetInputs() const { return combined_inputs_; }

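// Returns true if `in` can be mapped onto `out` by permuting its dimensions with `mask`,
// i.e. in->shape()[mask[dim]] == out->shape()[dim] for every output dimension.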
bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask, size_t mask_len) {
  MS_ASSERT(out->shape().size() == mask_len);
  for (std::size_t dim = 0; dim != out->shape().size(); ++dim) {
    if (in->shape().at(mask[dim]) != out->shape().at(dim)) {
      return false;
    }
  }
  return true;
}

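// Connects the head network to the backbone: every head input is matched against the
// backbone outputs, either directly or through an NCHW/NHWC dimension permutation.
// Matched pairs are recorded in backbone_head_map_; unmatched head inputs are exposed
// as extra inputs of the combined session. Fails if no head input could be connected.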
int TransferSession::CompileTransferGraph() {
  combined_inputs_ = backbone_session_->GetInputs();
  auto outputs_backbone = backbone_session_->GetOutputs();
  auto inputs_head = lite::TrainSession::GetInputs();

  int ret = RET_OK;
  for (auto input : inputs_head) {
    bool match = false;
    mindspore::tensor::MSTensor *output = nullptr;
    for (auto it = outputs_backbone.begin(); it != outputs_backbone.end(); ++it) {
      output = it->second;
      if (output->ElementsNum() == input->ElementsNum() && output->shape().size() == input->shape().size()) {
        match = true;
        for (std::size_t dim = 0; dim != output->shape().size(); ++dim) {
          if (input->shape().at(dim) != output->shape().at(dim)) {
            match = false;
            break;
          }
        }
        if (match == false && input->shape().size() == 4) {
          int nchw2nhwc_mask[4] = {0, 3, 1, 2};
          nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
          match = nchw2nhwc_;
        }
        if (match) {
          break;
        }
      }
    }
    if (match) {
      backbone_head_map_.push_back(std::make_pair(input, output));
    } else {
      combined_inputs_.push_back(input);
    }
  }
  if (backbone_head_map_.size() == 0) {
    ret = RET_ERROR;
  }
  return ret;
}

mindspore::tensor::MSTensor *TransferSession::GetInputsByTensorName(const std::string &tensor_name) const {
  /* First look in the backbone network */
  auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
  /* If not found, look in the head network */
  if (ret == nullptr) {
    ret = TrainSession::GetInputsByTensorName(tensor_name);
  }
  return ret;
}

TransferSession::~TransferSession() {
  if (backbone_session_ != nullptr) {
    delete backbone_session_;
    backbone_session_ = nullptr;
  }
  if (lite_model_ != nullptr) {
    free(lite_model_);
    lite_model_ = nullptr;
  }
}

void TransferSession::BindThread(bool if_bind) {
  backbone_session_->BindThread(if_bind);
  TrainSession::BindThread(if_bind);
}

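// Runs the backbone first, copies every mapped backbone output into the corresponding
// head input (repacking NCHW to NHWC when the shapes only matched under permutation),
// then runs the head network.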
int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
  auto ret = backbone_session_->RunGraph(before, after);
  if (ret != RET_OK) {
    return ret;
  }
  for (auto &backbone_head_pair : backbone_head_map_) {
    auto input = backbone_head_pair.first;
    auto output = backbone_head_pair.second;
    char *input_data = reinterpret_cast<char *>(input->MutableData());
    char *output_data = reinterpret_cast<char *>(output->MutableData());
    if (nchw2nhwc_) {
      int plane = input->shape().at(1) * input->shape().at(2);
      int batch = input->shape().at(0);
      int channel = input->shape().at(3);
      PackNCHWToNHWCFp32(output_data, input_data, batch, plane, channel, 0, 1);
    } else {
      std::copy(output_data, output_data + output->Size(), input_data);
    }
  }
  ret = lite::TrainSession::RunGraph(before, after);
  return ret;
}

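// Builds a map from head-input tensor indices (in this session's tensors_) to the
// backbone-output tensor indices they are fed from. Returns an empty map on error.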
std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
  std::unordered_map<size_t, size_t> map;
  for (auto &backbone_head_pair : backbone_head_map_) {
    auto input = backbone_head_pair.first;
    auto output = backbone_head_pair.second;
    auto in_id = TSFindTensorByName(tensors_, input->tensor_name());
    if (in_id == tensors_.size()) {
      MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name();
      map.clear();
      return map;
    }
    auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name());
    if (out_id == backbone_session_->tensors_.size()) {
      MS_LOG(ERROR) << "cannot find output tensor " << output->tensor_name();
      map.clear();
      return map;
    }
    map[in_id] = out_id;
  }
  return map;
}

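// Exports the model to a flatbuffer file. MT_TRAIN is delegated to TrainSession::Export;
// otherwise the backbone model is loaded, the backbone-to-head connection map (plus an
// NCHW->NHWC transform node, if one is needed) is attached, and the head inference
// kernels are serialized on top of it. The export runs in eval mode and the original
// train/eval state is restored afterwards.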
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
                            FormatType format, std::vector<std::string> out_put_tensor_name) {
  if (format != FT_FLATBUFFERS) {
    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
    return RET_ERROR;
  }

  if (model_type == MT_TRAIN) {
    return TrainSession::Export(filename, model_type, quant_type, format);
  }

  bool orig_train_state = IsTrain();
  Eval();
  TrainExport texport(filename);
  int status = texport.LoadModel(lite_model_, size_backbone_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot init export";
    return status;
  }
  auto connect_map = ConnectionMap();
  texport.set_connect(connect_map);
  if (nchw2nhwc_) {
    status = texport.AddTransformNode();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "cannot add transform node";
      return status;
    }
  }
  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    std::vector<kernel::LiteKernel *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindExportKernels failed.";
      return RET_ERROR;
    }
    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
  } else {
    status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot serialize head";
    return status;
  }
  status = texport.SaveToFile();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "failed to save to " << filename;
    return status;
  }
  if (orig_train_state) Train();
  return status;
}
}  // namespace lite

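// Builds a fully initialized TransferSession from in-memory backbone and head model
// buffers: validates the buffer sizes, creates the backbone session, compiles the head
// train graph and the transfer graph, and finally switches the session to train or eval
// mode. Returns nullptr on any failure.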
static session::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
                                                      const char *model_buf_head, size_t size_head,
                                                      const lite::Context *context, bool train_mode,
                                                      const lite::TrainCfg *cfg) {
  auto ValidModelSize = [](size_t size) -> bool {
    constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL;  // 1 GB
    return size < MaxModelSize && size > 0;
  };
  if (!ValidModelSize(size_backbone)) {
    MS_LOG(ERROR) << "size_backbone is invalid: " << size_backbone;
    return nullptr;
  }
  if (!ValidModelSize(size_head)) {
    MS_LOG(ERROR) << "size_head is invalid: " << size_head;
    return nullptr;
  }
  auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
  if (session == nullptr) {
    MS_LOG(ERROR) << "create transfer session failed";
    return nullptr;
  }
  if (!session->is_valid()) {
    MS_LOG(ERROR) << "create transfer session failed";
    delete session;
    return nullptr;
  }

  mindspore::lite::InnerContext *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
  if (inner_context == nullptr) {
    MS_LOG(ERROR) << "create inner context failed";
    delete session;
    return nullptr;
  }
  auto ret = session->Init(inner_context, cfg);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "init transfer session failed";
    delete session;
    return nullptr;
  }

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(model_buf_head, size_head));
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for head train session failed";
    delete session;
    return nullptr;
  }

  ret = session->CompileTrainGraph(model);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Compiling Train Graph failed";
    delete session;
    return nullptr;
  }
  ret = session->CompileTransferGraph();
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Compiling Transfer Graph failed";
    delete session;
    return nullptr;
  }

  if (train_mode) {
    ret = session->Train();
  } else {
    ret = session->Eval();
  }
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Could not switch to " << (train_mode ? "train" : "eval") << " mode";
    delete session;
    return nullptr;
  }
  return session;
}

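// Reads the backbone and head model files (appending the ".ms" extension when it is
// missing) and builds the combined transfer-learning session. A minimal usage sketch,
// for illustration only (the file names, the default-constructed context, and relying
// on RunGraph's default callback arguments are assumptions, not part of this file):
//
//   lite::Context context;
//   auto *session =
//     session::TrainSession::CreateTransferSession("backbone", "head", &context, true, nullptr);
//   if (session != nullptr) {
//     // fill the tensors returned by session->GetInputs(), then run backbone + head
//     session->RunGraph();
//     delete session;
//   }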
session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
                                                                   const std::string &filename_head,
                                                                   const lite::Context *ctxt, bool train_mode,
                                                                   const lite::TrainCfg *cfg) {
  size_t size_head = 0;
  size_t size_backbone = 0;
  std::string filename = filename_head;
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }

  auto buf_head = lite::ReadFile(filename.c_str(), &size_head);
  if (buf_head == nullptr) {
    return nullptr;
  }
  filename = filename_backbone;
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }

  auto buf_backbone = lite::ReadFile(filename.c_str(), &size_backbone);
  if (buf_backbone == nullptr) {
    // avoid leaking the already-read head buffer
    delete[] buf_head;
    return nullptr;
  }
  return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg);
}
}  // namespace mindspore