• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/transfer_session.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "src/common/utils.h"
27 #include "src/common/file_utils.h"
28 #include "src/tensor.h"
29 #include "src/train/loss_kernel.h"
30 #include "src/train/optimizer_kernel.h"
31 #include "src/train/train_populate_parameter.h"
32 #include "src/litert/executor.h"
33 #include "src/train/train_export.h"
34 #include "src/train/train_utils.h"
35 
36 namespace mindspore {
37 namespace lite {
TransferSession(const char * model_buf_backbone,size_t size_backbone,const std::shared_ptr<lite::InnerContext> & context)38 TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone,
39                                  const std::shared_ptr<lite::InnerContext> &context)
40     : is_valid_(false) {
41   lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
42   size_backbone_ = size_backbone;
43   if (lite_model_ != nullptr) {
44     std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
45     backbone_session_ = LiteSession::CreateSession(lite_model_, size_backbone, context);
46     if (backbone_session_ != nullptr) {
47       is_valid_ = true;
48     } else {
49       MS_LOG(ERROR) << "transfer session: create backbone session failed";
50     }
51   }
52 }
53 
GetInputs() const54 std::vector<lite::Tensor *> TransferSession::GetInputs() const { return combined_inputs_; }
55 
CompileFormatTransform(lite::Tensor * out,lite::Tensor * in,int * mask,size_t mask_len)56 bool TransferSession::CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len) {
57   MS_ASSERT(out->shape().size() == mask_len);
58   for (std::size_t dim = 0; dim != out->shape().size(); ++dim) {
59     if (in->shape().at(mask[dim]) != out->shape().at(dim)) {
60       return false;
61     }
62   }
63   return true;
64 }
65 
CompileTransferGraph()66 int TransferSession::CompileTransferGraph() {
67   combined_inputs_ = backbone_session_->GetInputs();
68   auto outputs_backbone = backbone_session_->GetOutputs();
69   auto inputs_head = lite::TrainSession::GetInputs();
70 
71   int ret = RET_OK;
72   for (auto input : inputs_head) {
73     bool match = false;
74     mindspore::lite::Tensor *output = nullptr;
75     for (auto it = outputs_backbone.begin(); it != outputs_backbone.end(); ++it) {
76       output = it->second;
77       if (output->ElementsNum() == input->ElementsNum() && output->shape().size() == input->shape().size()) {
78         match = true;
79         for (std::size_t dim = 0; dim != output->shape().size(); ++dim) {
80           if (input->shape().at(dim) != output->shape().at(dim)) {
81             match = false;
82             break;
83           }
84         }
85         if (match == false && input->shape().size() == 4) {
86           int nchw2nhwc_mask[4] = {0, 3, 1, 2};
87           nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
88           match = nchw2nhwc_;
89         }
90         if (match) {
91           break;
92         }
93       }
94     }
95     if (match) {
96       backbone_head_map_.push_back(std::make_pair(input, output));
97     } else {
98       combined_inputs_.push_back(input);
99     }
100   }
101   if (backbone_head_map_.size() == 0) {
102     ret = RET_ERROR;
103   }
104   return ret;
105 }
106 
GetInputsByTensorName(const std::string & tensor_name) const107 mindspore::lite::Tensor *TransferSession::GetInputsByTensorName(const std::string &tensor_name) const {
108   /* First look in backbone netwok */
109   auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
110   /* If not found look in head network */
111   if (ret == nullptr) {
112     ret = TrainSession::GetInputsByTensorName(tensor_name);
113   }
114   return ret;
115 }
116 
~TransferSession()117 TransferSession::~TransferSession() {
118   if (backbone_session_ != nullptr) {
119     delete backbone_session_;
120     backbone_session_ = nullptr;
121   }
122   if (lite_model_ != nullptr) {
123     free(lite_model_);
124     lite_model_ = nullptr;
125   }
126 }
// Forwards the thread-binding request to both underlying graphs: the backbone
// inference session and the trainable head (TrainSession base).
// NOTE(review): backbone_session_ is dereferenced unchecked — assumes callers
// only use a session whose is_valid() returned true; confirm against callers.
void TransferSession::BindThread(bool if_bind) {
  backbone_session_->BindThread(if_bind);
  TrainSession::BindThread(if_bind);
}
131 
RunGraph(const KernelCallBack & before,const KernelCallBack & after)132 int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
133   auto ret = backbone_session_->RunGraph(before, after);
134   if (ret != RET_OK) {
135     return ret;
136   }
137   for (auto &backbone_head_pair : backbone_head_map_) {
138     auto input = backbone_head_pair.first;
139     auto output = backbone_head_pair.second;
140     float *input_data = reinterpret_cast<float *>(input->MutableData());
141     float *output_data = reinterpret_cast<float *>(output->MutableData());
142     if (nchw2nhwc_) {
143       int batch = input->shape().at(0);
144       int plane = input->shape().at(1) * input->shape().at(2);
145       int channel = input->shape().at(3);
146       int img_size = plane * channel;
147       for (int b = 0; b < batch; b++) {
148         float *in = input_data + b * img_size;
149         float *out = output_data + b * img_size;
150         for (int p = 0; p < plane; p++) {
151           for (int c = 0; c < channel; c++) {
152             in[p * channel + c] = out[c * plane + p];
153           }
154         }
155       }
156     } else {
157       std::copy(output_data, output_data + output->ElementsNum(), input_data);
158     }
159   }
160   ret = lite::TrainSession::RunGraph(before, after);
161   return ret;
162 }
163 
ConnectionMap()164 std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
165   std::unordered_map<size_t, size_t> map;
166   for (auto &backbone_head_pair : backbone_head_map_) {
167     auto input = backbone_head_pair.first;
168     auto output = backbone_head_pair.second;
169     auto in_id = TSFindTensorByName(tensors_, input->tensor_name());
170     if (in_id == tensors_.size()) {
171       MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name();
172       map.clear();
173       return map;
174     }
175     auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name());
176     if (out_id == backbone_session_->tensors_.size()) {
177       MS_LOG(ERROR) << "cannot find input tensor " << output->tensor_name();
178       map.clear();
179       return map;
180     }
181     map[in_id] = out_id;
182   }
183   return map;
184 }
185 
// Exports the trained model. DestType selects the destination at compile time:
// a file path (const std::string &) or an in-memory buffer (Buffer *).
//
// model_type: MT_TRAIN exports the full training graph via TrainSession;
//             MT_INFERENCE exports only the inference portion of the head,
//             fused with the backbone through the connection map.
// quant_type / format: forwarded to the exporter; only FT_FLATBUFFERS is
//             supported.
// out_put_tensor_name: optional cut-points; when non-empty (and MT_INFERENCE),
//             only kernels feeding these tensors are exported.
// Returns RET_OK on success, otherwise an error status.
//
// NOTE(review): on early error returns after Eval() succeeds, the session is
// left in eval mode even if it was training before — confirm whether callers
// expect the train state to be restored on failure.
template <typename DestType>
int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                                 FormatType format, std::vector<std::string> out_put_tensor_name) {
  // Validate the destination according to its compile-time type.
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
  } else {
    MS_LOG(ERROR) << "Unsupported destination.";
    return RET_ERROR;
  }
  if (format != FT_FLATBUFFERS) {
    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
    return RET_ERROR;
  }

  // Full training export is delegated entirely to the base class.
  if (model_type == MT_TRAIN) {
    return TrainSession::Export(destination, model_type, quant_type, format);
  }

  // Inference export requires eval mode; remember the state so it can be
  // restored after a successful export.
  bool orig_train_state = IsTrain();
  if (Eval() != RET_OK) {
    MS_LOG(ERROR) << "eval failed.";
    return RET_ERROR;
  }
  TrainExport texport(destination);
  // Load the backbone model into the exporter; the head is appended below.
  int status = texport.LoadModel(lite_model_, size_backbone_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot init export";
    return status;
  }
  // Tell the exporter how head tensors map onto backbone tensors.
  auto connect_map = ConnectionMap();
  texport.set_connect(connect_map);
  if (nchw2nhwc_) {
    // The graphs were joined through a layout permutation, so an explicit
    // transform node must be inserted between them.
    status = texport.AddTransformNode();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "cannot add transform node";
      return status;
    }
  }
  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    // Export only the sub-graph needed to produce the requested tensors.
    std::vector<kernel::KernelExec *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindExportKernels failed.";
      return RET_ERROR;
    }
    status = texport.ExportNet(export_kernels, tensors_, {}, out_put_tensor_name, model_.get(), quant_type,
                               backbone_session_->model_);
  } else {
    // Export the whole inference graph of the head.
    status = texport.ExportNet(inference_kernels_, tensors_, {}, GetOutputTensorNames(), model_.get(), quant_type,
                               backbone_session_->model_);
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot serialize head";
    return status;
  }

  // Write the serialized model out through the type-appropriate sink.
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    status = texport.SaveToFile();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "failed to save to " << destination;
      return status;
    }
  } else {
    status = texport.SaveToBuffer();
    MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
  }

  // Restore training mode if that is how the session started.
  if (orig_train_state) {
    auto ret = Train();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "train failed.";
      return RET_ERROR;
    }
  }
  return status;
}
264 
// Exports the model to a file; thin forwarder to the templated ExportInner
// with DestType = const std::string &.
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
                            FormatType format, std::vector<std::string> out_put_tensor_name) {
  return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
}
269 
// Exports the model to an in-memory buffer; thin forwarder to the templated
// ExportInner with DestType = Buffer *.
int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
                            std::vector<std::string> out_put_tensor_name) {
  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
}
274 
CreateTransferSessionInt(const char * model_buf_backbone,size_t size_backbone,const char * model_buf_head,size_t size_head,const std::shared_ptr<InnerContext> & context,bool train_mode,const lite::TrainCfg * cfg)275 lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
276                                             const char *model_buf_head, size_t size_head,
277                                             const std::shared_ptr<InnerContext> &context, bool train_mode,
278                                             const lite::TrainCfg *cfg) {
279   auto ValidModelSize = [](size_t size) -> bool {
280     constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL;  // 1G B
281     return size < MaxModelSize && size > 0;
282   };
283   if (!ValidModelSize(size_backbone)) {
284     MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
285     return nullptr;
286   }
287   if (!ValidModelSize(size_head)) {
288     MS_LOG(ERROR) << "size_head too large: " << size_head;
289     return nullptr;
290   }
291   auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
292   if (session == nullptr) {
293     MS_LOG(ERROR) << "create transfer session failed";
294     return nullptr;
295   }
296   if (!session->is_valid()) {
297     MS_LOG(ERROR) << "create transfer session failed";
298     delete session;
299     return nullptr;
300   }
301   auto ret = session->TrainInit(context, cfg);
302   if (ret != lite::RET_OK) {
303     MS_LOG(ERROR) << "init transfer session failed";
304     delete session;
305     return nullptr;
306   }
307 
308   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(model_buf_head, size_head));
309   if (model == nullptr) {
310     MS_LOG(ERROR) << "create model for head train session failed";
311     delete session;
312     return nullptr;
313   }
314 
315   ret = session->CompileTrainGraph(model);
316   if (ret != lite::RET_OK) {
317     MS_LOG(ERROR) << "Compiling Train Graph failed";
318     delete session;
319     return nullptr;
320   }
321   ret = session->CompileTransferGraph();
322   if (ret != lite::RET_OK) {
323     MS_LOG(ERROR) << "Compiling Transfer Graph failed";
324     delete session;
325     return nullptr;
326   }
327 
328   if (train_mode) {
329     ret = session->Train();
330   } else {
331     ret = session->Eval();
332   }
333   if (ret != lite::RET_OK) {
334     MS_LOG(ERROR) << "Could not switch to Train Mode " << train_mode;
335     delete session;
336     return nullptr;
337   }
338   return session;
339 }
340 }  // namespace lite
341 }  // namespace mindspore
342