/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/train_session.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include "include/errorcode.h"
#include "src/litert/lite_model.h"
#include "src/litert/kernel_exec_util.h"
#include "src/tensor.h"
#include "src/litert/kernel_registry.h"
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "src/common/utils.h"
#include "src/train/optimizer_kernel.h"
#include "src/train/train_utils.h"
#include "src/train/train_export.h"
#include "src/train/opt_allocator.h"
#include "src/train/static_allocator.h"
#include "src/train/train_populate_parameter.h"
#include "src/train/train_populate_parameter_v0.h"

namespace mindspore {
namespace lite {
namespace {
void FreeGradients(const std::vector<lite::Tensor *> &gradients) {
  for (auto &gradient : gradients) {
    delete gradient;
  }
}

void AddNonConstTrainableParams(const std::vector<kernel::KernelExec *> &in_kernels, kernel::OptimizerKernel *optimizer,
                                std::vector<lite::Tensor *> *params) {
  auto indices = optimizer->GetTrainableParamsIdxs();
  if (params->size() == indices.size()) {
    return;
  }
  for (size_t ix = 0; ix < indices.size(); ix++) {
    auto param = optimizer->in_tensors().at(indices[ix]);
    if (param->IsConst()) {
      continue;
    }
    for (size_t i = 0; i < in_kernels.size(); i++) {
      auto out_tensors = in_kernels.at(i)->out_tensors();
      if (std::find(out_tensors.begin(), out_tensors.end(), param) != out_tensors.end() &&
          !in_kernels.at(i)->in_tensors().empty()) {
        auto filtered_tensor = in_kernels.at(i)->in_tensors().at(FIRST_INPUT);
        if (filtered_tensor->IsConst()) {
          params->emplace_back(filtered_tensor);
          break;
        }
      }
    }
  }
}
}  // namespace
const char *kGradName = "Gradients";
const char *kOptimizerName = "optimizer";
constexpr auto kObfNodeName = "obf_op-obf_mul";
constexpr size_t kFloatSize = 4;
constexpr int kDataIndex = 1;

TrainSession::TrainSession() {
  is_train_session_ = true;
  InitCallBack();
}

int TrainSession::TrainInit(const std::shared_ptr<InnerContext> &context, const TrainCfg *train_cfg) {
  if (train_cfg != nullptr) {
    if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
      MS_LOG(ERROR) << "illegal loss scale configuration";
      return RET_NULL_PTR;
    }
    cfg_ = *train_cfg;
  }
  if (context == nullptr) {
    MS_LOG(ERROR) << "context cannot be nullptr";
    return RET_NULL_PTR;
  }
  allocator_ = context->allocator;
  return lite::LiteSession::Init(context);
}

std::vector<CreatorOp> TrainSession::ReplaceOps() {
  const std::vector<CreatorOp> replace = {
    // currently no ops are Hijacked by TrainSession
  };
  mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
  std::vector<CreatorOp> results;
  for (auto v : replace) {
    const CreatorOp cl = make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v)));
    results.push_back(cl);
    reg->RegKernel(std::get<0>(v), std::get<1>(v));
  }
  return results;
}

void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) {
  mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
  for (auto v : restore) {
    reg->RegKernel(std::get<0>(v), std::get<1>(v));
  }
}

int TrainSession::AllocWorkSpace() {
  size_t workspace_size = 0;
  for (auto kernel : this->train_kernels_) {
    if (workspace_size < static_cast<kernel::LiteKernel *>(kernel->kernel())->workspace_size()) {
      workspace_size = static_cast<kernel::LiteKernel *>(kernel->kernel())->workspace_size();
    }
  }
  workspace_ = malloc(workspace_size);
  if (workspace_ == nullptr) {
    MS_LOG(ERROR) << "cannot allocate " << workspace_size << " for workspace";
    return RET_ERROR;
  }
  for (auto kernel : this->train_kernels_) {
    static_cast<kernel::LiteKernel *>(kernel->kernel())->set_workspace(workspace_);
  }
  return RET_OK;
}

void TrainSession::FreeWorkSpace() {
  if (workspace_ != nullptr) {
    free(workspace_);
    workspace_ = nullptr;
  }
  for (auto kernel : this->train_kernels_) {
    static_cast<kernel::LiteKernel *>(kernel->kernel())->FreeWorkspace();
  }
}

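// Registers the scheduler callback that decides, per node, whether the kernel runs in fp16 or fp32
// when CPU fp16 is enabled: Cast nodes follow their input tensor type, optimizer and loss nodes stay
// in fp32, and batchnorm nodes obey the keep_batchnorm_fp32 configuration.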
int TrainSession::InitCallBack() {
  sched_mix_precision_callback_ = [&](const LiteGraph::Node *node) {
    if (!context_->IsCpuFloat16Enabled()) {
      return false;
    }
    bool force_fp16 = false;
    auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
    if (node_type == schema::PrimitiveType_Cast) {
      schema::Tensor *tensor = model_.get()->graph_.all_tensors_.at(node->input_indices_[0]);
      if (tensor->dataType() == kNumberTypeFloat16) {
        force_fp16 = true;
      } else if (tensor->dataType() == kNumberTypeFloat32) {
        return false;
      }
    } else {
      auto in_size = node->input_indices_.size();
      for (std::size_t k = 0; k < in_size; k++) {
        schema::Tensor *tensor = model_->graph_.all_tensors_.at(node->input_indices_[k]);
        if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
          force_fp16 = true;
          break;
        }
      }
    }

    const auto &node_name = node->name_;
    bool is_fp16 = true;
    if (!force_fp16) {
      // optimizer runs in fp32
      if (node_name.find(kOptimizerName) != std::string::npos) {
        is_fp16 = false;
      }
      // loss function runs in fp32
      auto v = get_loss_name();
      for (auto &s : v) {
        if (node_name.find(s) != std::string::npos) {
          is_fp16 = false;
          break;
        }
      }
      // run bn according to user configuration
      if ((cfg_.mix_precision_cfg_.keep_batchnorm_fp32_) &&
          (node_type == schema::PrimitiveType_FusedBatchNorm || node_type == schema::PrimitiveType_BatchNorm ||
           node_type == schema::PrimitiveType_BatchNormGrad)) {
        is_fp16 = false;
      }
    }
    MS_LOG(DEBUG) << "Debug: " << node_name << ((is_fp16) ? " fp16" : " fp32");
    return is_fp16;
  };
  return RET_OK;
}

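// When a static allocator is in use, plans a single shared buffer for all kernel output tensors:
// offsets are assigned greedily via OptAllocator, with in-place reuse where possible and space
// released once a tensor's reference count drops to zero; the buffer is then (re)allocated and
// each output tensor's data pointer is set to its computed offset.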
int TrainSession::AllocTensors(const std::vector<kernel::KernelExec *> &kernels) {
  if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
  OptAllocator allocator;
  std::unordered_map<lite::Tensor *, int> ref_count;
  std::unordered_map<lite::Tensor *, size_t> offset_map;
  int counter = 0;
  uint32_t input_idx = 0;
  for (auto &kernel : kernels) {
    for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
      auto tensor = kernel->out_tensors().at(i);
      bool in_place = false;
      if (counter != 0) {
        in_place = IsInPlaceTensor(kernel, i, ref_count, &input_idx);
      }
      counter++;
      size_t offset;
      if (in_place) {
        offset = GetInplaceTensorOffset(kernel, offset_map, &ref_count, input_idx);
      } else {
        size_t size = tensor->Size();
        offset = allocator.Malloc(size);
      }
      offset_map[tensor] = offset;
      ref_count[tensor] = tensor->init_ref_count();
    }
    for (auto tensor : kernel->in_tensors()) {
      if (tensor->category() == lite::Category::VAR) {
        int count = ref_count[tensor] - 1;
        ref_count[tensor] = count;
        if (count == 0) {
          allocator.Free(offset_map[tensor]);
        }
      }
    }
  }
  // Set Tensor data
  auto size = allocator.total_size();
  if (size > tensors_data_size_) {
    free(tensors_data_);
    tensors_data_ = nullptr;
  }
  if (tensors_data_ == nullptr) {
    auto buf = malloc(size);
    if (buf == nullptr) {
      MS_LOG(ERROR) << "cannot allocate buffer size" << size;
      return RET_ERROR;
    }
    StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
    alloc->SetContex(buf, size);
    tensors_data_ = buf;
    tensors_data_size_ = size;
  }
  for (auto kernel : train_kernels_) {
    for (auto tensor : kernel->out_tensors()) {
      auto it = offset_map.find(tensor);
      if (it != offset_map.end()) {
        tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensors_data_) + it->second));
      }
    }
  }
  return RET_OK;
}

int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }

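// Compiles the training graph: temporarily swaps kernel creators (ReplaceOps), registers the
// mixed-precision scheduler callback, compiles the base graph, then derives the train/eval kernel
// lists, outputs, trainable parameters and const-folded kernels, and finally allocates the shared
// workspace and tensor memory.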
int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
  model_ = model;
  auto restore = ReplaceOps();
  sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
  if (sched_cb_ == nullptr) {
    MS_LOG(ERROR) << "Failed to create SchedulerCb node";
    return RET_ERROR;
  }

  if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_CUR) {
    kernel::PopulateTrainParameters();
  }

  auto ret = lite::LiteSession::CompileGraph(model_.get());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to compile train model";
    return RET_ERROR;
  }
  orig_output_node_map_ = output_node_map_;
  orig_output_tensor_map_ = output_tensor_map_;
  orig_output_tensor_names_ = output_tensor_names_;
  for (auto inTensor : inputs_) inTensor->MutableData();
  RestoreOps(restore);
  CompileTrainKernels();      // Prepare a list of train kernels
  CompileOptimizedKernels();  // Prepare a list of kernels which are optimized (weight update step)
  CompileTrainableParams();   // Prepare trainable parameters of optimizers
  CompileTrainOutputs();      // prepare outputs in train mode
  CompileEvalOutputs();       // prepare outputs in eval mode
  // Prepare a list of eval kernels
  if (CompileInferenceKernels() != RET_OK) {
    MS_LOG(WARNING) << "CompileInferenceKernels failed.";
    return RET_ERROR;
  }
  ret = AllocWorkSpace();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  // Prepare a list of kernels which are const folded
  MS_CHECK_TRUE_MSG(CompileConstFoldedKernels() == RET_OK, RET_ERROR, "CompileConstFoldedKernels failed.");
  return RET_OK;
}

TrainSession::~TrainSession() {
  FreeWorkSpace();
  if (tensors_data_ != nullptr) {
    free(tensors_data_);
    tensors_data_ = nullptr;
  }
}

int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                              const std::vector<kernel::KernelExec *> &run_kernels) {
  for (auto *kernel : run_kernels) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Execute(before, after);
    if (RET_OK != ret) {
      MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
      return ret;
    }
  }
  return RET_OK;
}

void TrainSession::RestoreTensorData() {
  for (auto &restored_origin_tensor : restored_origin_tensors_) {
    auto *origin_tensor = restored_origin_tensor.first;
    auto *restored_tensor = restored_origin_tensor.second;
    MS_ASSERT(origin_tensor != nullptr);
    MS_ASSERT(restored_tensor != nullptr);

    bool own_data = restored_tensor->own_data();
    if (origin_tensor->data() == nullptr) {
      restored_tensor->FreeData();
    } else {
      origin_tensor->FreeData();
    }
    origin_tensor->set_data_type(restored_tensor->data_type());
    origin_tensor->set_data(restored_tensor->data());
    origin_tensor->set_own_data(own_data);
  }
}

void TrainSession::FreeRestoreTensors() {
  for (auto &restored_origin_tensor : restored_origin_tensors_) {
    auto *restored_tensor = restored_origin_tensor.second;
    restored_tensor->set_data(nullptr);
    delete (restored_tensor);
  }
  restored_origin_tensors_.clear();
}

bool TrainSession::IsLossTensor(Tensor *tensor) {
  MS_ASSERT(tensor != nullptr);
  bool isLoss = false;
  auto t_n = tensor->tensor_name();
  auto v = get_loss_name();
  for (auto &s : v) {
    if (t_n.find(s) != std::string::npos) {
      isLoss = true;
      break;
    }
  }
  return isLoss;
}

bool TrainSession::AllInputsNeedScale(kernel::KernelExec *kernel) {
  auto type = kernel->type();
  bool is_scale = false;
  switch (type) {
    case schema::PrimitiveType_AbsGrad:
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion:
    case schema::PrimitiveType_AddN:
      for (auto &tensor : kernel->in_tensors()) {
        is_scale = is_scale || tensor->IsScale();
      }
      return (is_scale);
    default:
      return false;
  }
  return false;
}

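// Before a kernel runs in mixed precision: loss tensors that feed a non-loss kernel (or all inputs,
// for kernel types where every input needs scaling) are multiplied by the loss scale, and inputs
// whose data type differs from the kernel's type are cast, keeping the originals so they can be
// restored after execution.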
int TrainSession::MixPrecisionPreProcess(kernel::KernelExec *kernel, float scale) {
  auto kernel_type = kernel->desc().data_type;
  auto all_scale = AllInputsNeedScale(kernel);

  for (auto &tensor : kernel->in_tensors()) {
    if ((tensor->IsScale() == false) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
      ScaleTensor(tensor, scale);
    }
    // adjust tensor data type
    if (tensor->data_type() != kernel_type) {
      auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16_);
      if (restore_tensor != nullptr) {
        restored_origin_tensors_[tensor] = restore_tensor;
      }
    }
  }
  return RET_OK;
}

int TrainSession::MixPrecisionPostProcess(kernel::KernelExec *kernel) {
  RestoreTensorData();
  FreeRestoreTensors();

  float scale = 1.0f;
  auto all_scale = AllInputsNeedScale(kernel);
  for (auto &tensor : kernel->in_tensors()) {
    if (tensor->IsScale()) {
      scale *= tensor->get_scale();
      if (all_scale) {
        break;
      }
    }
  }
  for (auto &tensor : kernel->out_tensors()) {
    tensor->set_scale(scale);
  }

  for (auto &tensor : kernel->in_tensors()) {
    if (tensor->IsScale() && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || all_scale)) {
      ScaleTensor(tensor, 1.0f / scale);
    }
  }
  return RET_OK;
}

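// Runs the kernel list with dynamic loss scaling: each kernel is pre/post-processed for scaling and
// casting, the loss scale is halved when a kernel reports an out-of-range (NaN/Inf) result, doubled
// (up to 65536) after a configurable number of clean iterations, and eval-mode fp16 outputs are cast
// back to fp32.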
int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                                          const std::vector<kernel::KernelExec *> &run_kernels) {
  float scale = cfg_.mix_precision_cfg_.loss_scale_;
  for (auto *kernel : run_kernels) {
    MS_ASSERT(kernel != nullptr);
    auto ret = MixPrecisionPreProcess(kernel, scale);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "MixPrecisionPreProcess failed.";
      return RET_ERROR;
    }
    ret = kernel->Execute(before, after);
    if (RET_OK != ret) {
      MixPrecisionPostProcess(kernel);
      // decrease loss scale in case of nan or inf
      if (ret == RET_OUT_OF_TENSOR_RANGE) {
        bool is_dynamic_scale = cfg_.mix_precision_cfg_.dynamic_loss_scale_;
        cfg_.mix_precision_cfg_.loss_scale_ = std::max(((is_dynamic_scale) ? (scale / 2.f) : scale), 1.0f);
        num_of_not_nan_iter_ = 0;
        return RET_OK;
      }
      MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
      return ret;
    }
    MixPrecisionPostProcess(kernel);
  }
  // increase the dynamic loss scale once the count of not-NaN iterations passes the threshold
  if (cfg_.mix_precision_cfg_.dynamic_loss_scale_) {
    num_of_not_nan_iter_++;
    if (num_of_not_nan_iter_ >= cfg_.mix_precision_cfg_.num_of_not_nan_iter_th_) {
      cfg_.mix_precision_cfg_.loss_scale_ = std::min((cfg_.mix_precision_cfg_.loss_scale_ * 2.0f), 65536.0f);
      num_of_not_nan_iter_ = 0;
    }
  }

  // cast output to FP32
  if (train_mode_ == false) {
    for (auto t : this->outputs_) {
      if (t->data_type() == kNumberTypeFloat16) {
        auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16_);
        delete restore;
      }
    }
  }
  return RET_OK;
}

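// Runs one forward (eval) or forward+backward (train) pass: validates inputs, rebuilds the output
// tensor list, executes either the train or the inference kernel list (with mixed precision if CPU
// fp16 is enabled), and triggers an OptimizerStep once per virtual batch when virtual batching is
// configured.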
int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
  // check inputs
  auto ret = CheckTensorsInvalid(inputs_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CheckInputs failed";
    return ret;
  }

  // build out tensor
  this->outputs_.clear();
  for (auto &ms_tensors : output_node_map_) {
    for (auto &ms_tensor : ms_tensors.second) {
      auto lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      this->outputs_.push_back(lite_tensor);
    }
  }

  if (this->context_ == nullptr) {
    MS_LOG(ERROR) << "context is null";
    return lite::RET_NULL_PTR;
  }
  auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
  if (context_->IsCpuFloat16Enabled()) {
    ret = MixPrecisionExecKernels(before, after, run_kernels);
  } else {
    ret = ExecKernels(before, after, run_kernels);
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to run model kernels";
    return ret;
  }

  if (train_mode_ && (virtual_batch_multiplier_ != 0)) {
    virtual_batch_idx_++;
    if (virtual_batch_idx_ >= virtual_batch_multiplier_) {
      virtual_batch_idx_ = 0;
      ret = OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "failed to optimize model weights";
        return ret;
      }
    }
  }
  return RET_OK;
}

int TrainSession::Train() {
  // shift kernels to train mode
  train_mode_ = true;
  virtual_batch_idx_ = 0;
  for (auto &kernel : this->train_kernels_) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Train();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
      return RET_ERROR;
    }
  }
  // set train outputs
  output_node_map_ = train_output_node_map_;
  output_tensor_map_ = train_output_tensor_map_;
  output_tensor_names_ = train_output_tensor_names_;
  kernel::KernelExecUtil::InitTensorInitRefCount(train_kernels_);
  for (auto &ms_tensors : eval_output_node_map_) {  // also allow predictions to be examined during training
    for (auto &ms_tensor : ms_tensors.second) {
      lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
    }
  }
  // allocate tensors
  auto ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate tensor space";
    return RET_ERROR;
  }
  return RET_OK;
}

int TrainSession::Eval() {
  // shift kernels to eval mode
  train_mode_ = false;
  virtual_batch_idx_ = 0;
  for (auto &kernel : this->train_kernels_) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Eval();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << kernel->name() << " failed to set eval mode";
      return RET_ERROR;
    }
  }
  // set eval outputs
  output_node_map_ = eval_output_node_map_;
  output_tensor_map_ = eval_output_tensor_map_;
  output_tensor_names_ = eval_output_tensor_names_;
  kernel::KernelExecUtil::InitTensorInitRefCount(inference_kernels_);
  for (auto &ms_tensors : eval_output_node_map_) {
    for (auto &ms_tensor : ms_tensors.second) {
      lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
    }
  }
  auto ret = AllocTensors(inference_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  return RET_OK;
}

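// Builds the eval-mode output maps: for every loss kernel, the outputs of its non-loss input kernels
// become the network outputs, so that in eval mode the graph effectively stops right before the loss
// computation; falls back to the original outputs if nothing is found.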
void TrainSession::CompileEvalOutputs() {
  eval_output_node_map_.clear();
  eval_output_tensor_map_.clear();
  eval_output_tensor_names_.clear();
  for (auto kernel : this->train_kernels_) {
    if (!IsLossKernel(kernel) || IsGradKernel(kernel)) {
      continue;
    }
    // if LossKernel and not GradKernel, deal with outputs
    for (auto in_kernel : kernel->in_kernels()) {
      bool is_loss = IsLossInKernel(in_kernel);
      if (is_loss) {
        continue;
      }
      // insert if not already in
      auto out_tensors = TSFindTensors(in_kernel, kernel);
      if (eval_output_node_map_.find(in_kernel->name()) != eval_output_node_map_.end()) {
        auto exist_out_tensors = eval_output_node_map_[in_kernel->name()];
        auto kernel_all_out_tensors = in_kernel->out_tensors();
        eval_output_node_map_[in_kernel->name()] = {};
        for (auto tensor : kernel_all_out_tensors) {
          if (std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end() ||
              std::find(exist_out_tensors.begin(), exist_out_tensors.end(), tensor) != exist_out_tensors.end()) {
            eval_output_node_map_[in_kernel->name()].emplace_back(tensor);
          }
        }
      } else {
        eval_output_node_map_[in_kernel->name()] = out_tensors;
      }
      for (auto out_tensor : out_tensors) {
        auto index = TSFindTensor(tensors_, out_tensor);
        if (std::find(eval_output_tensor_names_.begin(), eval_output_tensor_names_.end(), out_tensor->tensor_name()) !=
              eval_output_tensor_names_.end() ||
            std::find(eval_output_tensor_names_.begin(), eval_output_tensor_names_.end(), std::to_string(index)) !=
              eval_output_tensor_names_.end()) {
          continue;
        }
        if (index != tensors_.size()) {
          if (!out_tensor->tensor_name().empty()) {
            eval_output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor));
            eval_output_tensor_names_.emplace_back(out_tensor->tensor_name());
          } else {
            eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), out_tensor));
            eval_output_tensor_names_.emplace_back(std::to_string(index));
          }
        }
      }
    }
  }
  if (eval_output_node_map_.empty()) eval_output_node_map_ = orig_output_node_map_;
  if (eval_output_tensor_map_.empty()) eval_output_tensor_map_ = orig_output_tensor_map_;
  if (eval_output_tensor_names_.empty()) eval_output_tensor_names_ = orig_output_tensor_names_;
}

void TrainSession::CompileTrainOutputs() {
  train_output_node_map_.clear();
  train_output_tensor_map_.clear();
  train_output_tensor_names_.clear();
  for (auto kernel : this->train_kernels_) {
    if (orig_output_node_map_.find(kernel->name()) == orig_output_node_map_.end()) continue;
    // Mask out optimizer out tensors
    if (IsMaskOutput(kernel)) continue;
    // insert if not already in
    if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) {
      auto *ms_tensor = kernel->out_tensors().at(0);
      if (ms_tensor != nullptr) {
        train_output_node_map_[kernel->name()].emplace_back(ms_tensor);
        auto index = TSFindTensor(tensors_, ms_tensor);
        if (index != tensors_.size()) {
          if (!ms_tensor->tensor_name().empty()) {
            train_output_tensor_map_.insert(std::make_pair(ms_tensor->tensor_name(), ms_tensor));
            train_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
          } else {
            train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
            train_output_tensor_names_.emplace_back(std::to_string(index));
          }
        }
      }
    }
  }
  if (train_output_node_map_.empty()) train_output_node_map_ = orig_output_node_map_;
  if (train_output_tensor_map_.empty()) train_output_tensor_map_ = orig_output_tensor_map_;
  if (train_output_tensor_names_.empty()) train_output_tensor_names_ = orig_output_tensor_names_;
}

void TrainSession::BuildInferenceKernelsRecursive(kernel::KernelExec *kernel, std::vector<kernel::KernelExec *> *v) {
  MS_ASSERT(kernel != nullptr);
  MS_ASSERT(v != nullptr);
  if (std::find(v->begin(), v->end(), kernel) == v->end()) {  // kernel is not already in vector
    for (auto in_node : kernel->in_kernels()) {
      BuildInferenceKernelsRecursive(in_node, v);
    }
    if (!IsLossKernel(kernel)) v->push_back(kernel);
  }
}

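// Flattens subgraph kernels into a single list, links matching LSTMGradData/LSTMGradWeight pairs so
// their execution is ordered, and then topologically sorts the kernels (Kahn's algorithm over
// in-kernel reference counts) into train_kernels_.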
void TrainSession::CompileTrainKernels() {
  train_kernels_.clear();
  std::vector<kernel::KernelExec *> train_kernels;
  for (auto ori_kernel : kernels_) {
    if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
      train_kernels.push_back(ori_kernel);
    } else {
      auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
      std::copy(sub_graph->nodes().begin(), sub_graph->nodes().end(), std::back_inserter(train_kernels));
    }
  }
  // For LSTM, GPU operators are synchronized internally, hence we need to add a sync mechanism to the execution graph
  for (auto k : train_kernels) {
    if (k->type() == schema::PrimitiveType_LSTMGradWeight) {
      // Find PrimitiveType_LSTMGradData that matches this PrimitiveType_LSTMGradWeight
      for (auto mk : train_kernels) {
        if (mk->type() == schema::PrimitiveType_LSTMGradData) {
          if (k->in_tensors().at(C2NUM)->tensor_name() == mk->in_tensors().at(0)->tensor_name()) {
            mk->AddOutKernel(k);
            k->AddInKernel(mk);
          }
        }
      }
    }
  }
  std::queue<kernel::KernelExec *> queue;
  for (auto k : train_kernels) {
    auto in_kernels = k->in_kernels();
    if (in_kernels.size() == 0) {
      queue.push(k);
    }
  }
  std::unordered_map<kernel::KernelExec *, int> map;
  while (!queue.empty()) {
    // pop first element
    auto k = queue.front();
    train_kernels_.push_back(k);
    queue.pop();
    for (auto &ok : k->out_kernels()) {
      auto cnt_iter = map.find(ok);
      if (cnt_iter == map.end()) {
        int ref_cnt = ok->in_kernels().size();
        map[ok] = ref_cnt;
      }
      cnt_iter = map.find(ok);
      auto ref_cnt = cnt_iter->second - 1;
      map[ok] = ref_cnt;
      if (ref_cnt <= 0) {
        queue.push(cnt_iter->first);
      }
    }
  }
}

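// Collects the kernels reachable (walking backwards) from the eval output nodes into
// inference_kernels_; if every train kernel is also needed for inference, the model is treated as an
// inference-only model and an error is returned.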
int TrainSession::CompileInferenceKernels() {
  inference_kernels_.clear();
  for (auto item : eval_output_node_map_) {
    std::string kernel_name = item.first;
    auto kernel = TSFindKernel(train_kernels_, kernel_name);
    if (kernel == nullptr) {
      MS_LOG(ERROR) << "kernel is nullptr";
      return RET_ERROR;
    }
    BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
  }

  if (train_kernels_.size() == inference_kernels_.size()) {
    MS_LOG(WARNING) << "This is an inference model; returning error in TrainSession.";
    return RET_ERROR;
  }

  if (inference_kernels_.size() == 0) {
    inference_kernels_ = this->train_kernels_;
  }
  return RET_OK;
}

void TrainSession::CompileOptimizedKernels() {
  std::vector<lite::Tensor *> out_tensor;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
      if (cfg_.accumulate_gradients_) {
        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
        auto ret = optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "SetOptimizerMode failed.";
          return;
        }
      }
    }
  }

  for (auto kernel : this->train_kernels_) {
    if (!IsOptimizer(kernel)) {
      for (auto it : kernel->in_tensors()) {
        if (std::find(out_tensor.begin(), out_tensor.end(), it) != out_tensor.end()) {
          kernel->SetTrainable(true);
          break;
        }
      }
    }
  }
}

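// Executes kernels whose inputs are all already constant (and not graph inputs) once at compile time
// and marks their outputs as constant; kernels that still depend on non-constant inputs are collected
// into const_fold_kernels_ for later use at export time.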
int TrainSession::CompileConstFoldedKernels() {
  const_output_tensors_.clear();
  for (auto kernel : this->inference_kernels_) {
    bool is_input_const = true;
    for (auto input : kernel->in_tensors()) {
      if ((!input->IsConst() || input->IsGraphInput()) &&
          std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) {
        is_input_const = false;
      }
      if (!is_input_const) {
        const_fold_kernels_.emplace_back(kernel);
        break;
      }
    }
    if (is_input_const) {
      auto ret = kernel->Execute();
      if (RET_OK != ret) {
        MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
        return ret;
      }
      for (auto output : kernel->out_tensors()) {
        const_output_tensors_.emplace_back(output);
        output->set_category(Category::CONST_TENSOR);
      }
    }
  }
  return RET_OK;
}

int TrainSession::FindConstFoldedKernels() {
  float obf_ratio = ModelRecoverObfuscate(false);
  if (!FloatCompare(obf_ratio, 1.0)) {
    MS_LOG(INFO) << "obfuscated model does not need const folding.";
    const_fold_kernels_ = this->inference_kernels_;
    const_output_tensors_ = {};
    return RET_OK;
  }
  const_fold_kernels_.clear();
  const_output_tensors_.clear();
  for (auto kernel : this->inference_kernels_) {
    bool is_input_const = true;
    for (auto input : kernel->in_tensors()) {
      if ((!input->IsConst() || input->IsGraphInput()) &&
          std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) {
        is_input_const = false;
      }
      if (!is_input_const) {
        const_fold_kernels_.emplace_back(kernel);
        break;
      }
    }
    if (is_input_const) {
      auto ret = kernel->Execute();
      if (RET_OK != ret) {
        MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
        return ret;
      }
      for (auto output : kernel->out_tensors()) {
        output->set_category(Category::CONST_TENSOR);
        const_output_tensors_.emplace_back(output);
      }
    }
  }
  return RET_OK;
}

void TrainSession::CompileTrainableParams() {
  for (auto kernel : this->train_kernels_) {
    if (!IsOptimizer(kernel)) {
      continue;
    }
    auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
    auto params = optimizer->GetTrainableParams();
    auto in_kernels = kernel->in_kernels();
    AddNonConstTrainableParams(in_kernels, optimizer, &params);

    for (auto param : params) {
      if (std::find(trainable_parameters_.begin(), trainable_parameters_.end(), param) != trainable_parameters_.end()) {
        continue;
      }
      trainable_parameters_.emplace_back(param);
    }
  }
}

int TrainSession::SetLearningRate(float learning_rate) {
  if (learning_rate < 0.0f) {
    MS_LOG(ERROR) << "learning rate should not be less than 0";
    return RET_ERROR;
  }
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto ret = optimizer->SetLearningRate(learning_rate);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

float TrainSession::GetLearningRate() {
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      return optimizer->GetLearningRate();
    }
  }
  return 0.0;
}

std::vector<lite::Tensor *> TrainSession::GetOptimizerParams() const {
  std::vector<lite::Tensor *> params;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto kernelParams = optimizer->GetOptimizerParams();
      for (size_t ix = 0; ix < kernelParams.size(); ix++) {
        auto kernelParam = kernelParams[ix];
        auto name = kernelParam->tensor_name();
        bool found = false;
        for (size_t iy = 0; iy < params.size(); iy++) {
          if (params[iy]->tensor_name() == name) {
            found = true;
            break;
          }
        }
        if (!found) {
          params.push_back(kernelParam);
        }
      }
    }
  }
  return params;
}

int TrainSession::SetOptimizerParams(const std::vector<lite::Tensor *> &params) {
  for (size_t ix = 0; ix < params.size(); ix++) {
    auto param = params[ix];
    if (param == nullptr) {
      MS_LOG(ERROR) << "Param tensor is null.";
      return RET_ERROR;
    }
    bool found = false;
    for (auto kernel : this->train_kernels_) {
      if (IsOptimizer(kernel)) {
        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
        found = optimizer->SetOptimizerParams(param);
        if (found) break;
      }
    }
    if (!found) {
      MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elements and type "
                    << param->data_type() << " is not a valid params tensor";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

std::vector<lite::Tensor *> TrainSession::GetGradients() const {
  std::vector<lite::Tensor *> gradients;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto kernelGradint = optimizer->GetGradients();
      if (kernelGradint != nullptr) {
        gradients.push_back(kernelGradint);
      }
    }
  }
  return gradients;
}

int TrainSession::ApplyGradients(const std::vector<lite::Tensor *> &gradients) {
  auto current_gradients = GetGradients();
  if (current_gradients.size() != gradients.size()) {
    MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
                  << current_gradients.size();
    FreeGradients(current_gradients);
    return RET_ERROR;
  }
  for (size_t ix = 0; ix < gradients.size(); ix++) {
    auto gradient = gradients[ix];
    if (gradient == nullptr) {
      MS_LOG(ERROR) << "gradient tensor is null.";
      FreeGradients(current_gradients);
      return RET_ERROR;
    }
    bool found = false;
    for (size_t iy = 0; iy < current_gradients.size(); iy++) {
      auto current_gradient = current_gradients[iy];
      if (current_gradient->tensor_name() == gradient->tensor_name()) {
        found = true;
        if (current_gradient->Size() == gradient->Size()) {
          std::copy(static_cast<uint8_t *>(gradient->data()),
                    static_cast<uint8_t *>(gradient->data()) + gradient->Size(),
                    static_cast<uint8_t *>(current_gradient->MutableData()));
        } else {
          MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
                        << " instead of " << current_gradient->Size();
          FreeGradients(current_gradients);
          return RET_ERROR;
        }
        break;
      }
    }
    if (!found) {
      MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
      FreeGradients(current_gradients);
      return RET_ERROR;
    }
  }
  FreeGradients(current_gradients);
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      if (optimizer->set_grad_sum_valid() != RET_OK) {
        MS_LOG(ERROR) << "set grad sum valid failed.";
        return RET_ERROR;
      }
      auto ret = optimizer->OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "failed to optimize model weights";
        return ret;
      }
    }
  }
  return RET_OK;
}

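// Switches every optimizer kernel between NORMAL and VIRTUAL_BATCH weight-update modes; in
// virtual-batch mode the learning rate defaults to the optimizer's current rate divided by the
// multiplier, and trainable batchnorm kernels get their momentum adjusted accordingly.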
int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
  auto mod =
    (virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
  virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
  virtual_batch_idx_ = 0;

  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
          optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
        MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
        return RET_ERROR;
      }
      auto ret = optimizer->SetOptimizerMode(mod);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
        return RET_ERROR;
      }
      if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
        lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
        ret = optimizer->SetLearningRate(lr);
      } else {
        ret = optimizer->RestoreDefaultLearningRate();
      }
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
        return RET_ERROR;
      }
    }

    if (IsBN(kernel) && kernel->IsTrainable()) {
      auto batchnorm = static_cast<kernel::LiteKernel *>(kernel->kernel());
      auto ret = batchnorm->SetupVirtualBatch(virtual_batch_multiplier_, momentum);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set momentum";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
  int tmp = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
  if (tmp != 0 && virtual_batch_multiplier_ != 0) {
    AdminSetupVirtualBatch(0, lr, momentum);
  }
  return AdminSetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
}

int TrainSession::OptimizerStep() {
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto ret = optimizer->OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

bool TrainSession::IsLossInKernel(const kernel::KernelExec *kernel) const {
  if (IsLossKernel(kernel) || IsGradKernel(kernel)) {
    return true;
  }
  for (auto in_kernel : kernel->in_kernels()) {
    if (IsLossKernel(in_kernel)) {
      return true;
    }
  }
  return false;
}

bool TrainSession::IsLossKernel(const kernel::KernelExec *kernel) const {
  bool isLoss = false;
  for (auto &s : cfg_.loss_name_) {
    if (kernel->name().find(s) != std::string::npos) {
      isLoss = true;
      break;
    }
  }
  return isLoss;
}

bool TrainSession::IsGradKernel(const kernel::KernelExec *kernel) const {
  return kernel->name().find(kGradName) != std::string::npos;
}

bool TrainSession::IsOptimizer(kernel::KernelExec *kernel) const {
  return ((kernel->type() == schema::PrimitiveType_Adam) || (kernel->type() == schema::PrimitiveType_SGD) ||
          (kernel->type() == schema::PrimitiveType_ApplyMomentum));
}

bool TrainSession::IsMaskOutput(kernel::KernelExec *kernel) const {
  return (IsOptimizer(kernel) || (kernel->type() == schema::PrimitiveType_Assign));
}

bool TrainSession::IsBN(kernel::KernelExec *kernel) const {
  return ((kernel->type() == schema::PrimitiveType_BatchNorm) ||
          (kernel->type() == schema::PrimitiveType_FusedBatchNorm));
}

int TrainSession::Resize(const std::vector<lite::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) {
  FreeWorkSpace();
  if (tensors_data_ != nullptr) {
    free(tensors_data_);
    tensors_data_ = nullptr;
  }
  auto ret = lite::LiteSession::Resize(inputs, dims);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "train resize input failed.";
    return RET_ERROR;
  }
  ret = AllocWorkSpace();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "train alloc failed after resize.";
    return RET_ERROR;
  }
  return RET_OK;
}

int TrainSession::FindUseInTensorKernel(std::vector<kernel::KernelExec *> *use_in_tensor_kernels,
                                        const std::vector<lite::Tensor *> &kernel_in_tensors,
                                        const std::vector<kernel::KernelExec *> &inference_kernels) {
  for (size_t i = 0; i < inference_kernels.size(); i++) {
    for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
      if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
        use_in_tensor_kernels->push_back(inference_kernels[i]);
      }
    }
  }
  return RET_OK;
}

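// Starting from the kernels that produce the requested output tensors, walks the graph backwards
// (breadth-first over the producers of each kernel's inputs) and collects every kernel needed to
// export the requested sub-graph.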
int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_kernels,
                                    const std::vector<std::string> &export_output_tensor_names,
                                    const std::vector<kernel::KernelExec *> &inference_kernels) {
  std::vector<std::string> all_kernel_name = {};
  (void)std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
                       [](kernel::KernelExec *kernel) { return kernel->name(); });
  std::queue<std::string> need_kernel_names;
  // Find the kernel name according to the tensor name
  for (auto &kernel : inference_kernels) {
    if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
          return IsContain(export_output_tensor_names, out_tensor->tensor_name());
        })) {
      need_kernel_names.push(kernel->name());
    }
  }
  if (need_kernel_names.size() == 0) {
    MS_LOG(ERROR) << "cannot find a kernel that produces the requested output tensors";
    return RET_ERROR;
  }
  // find all kernels
  while (!need_kernel_names.empty()) {
    auto kernel_name = need_kernel_names.front();
    need_kernel_names.pop();
    auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
    if (it == all_kernel_name.end()) {
      MS_LOG(ERROR) << "cannot find kernel name when exporting the trained model.";
      return RET_ERROR;
    }
    auto kernel = inference_kernels[it - all_kernel_name.begin()];
    if (!IsContain(*export_kernels, kernel)) {
      export_kernels->push_back(kernel);
    }
    auto kernel_in_tensors = kernel->in_tensors();
    std::vector<kernel::KernelExec *> use_in_tensor_kernels;
    auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
      return RET_ERROR;
    }
    for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
      need_kernel_names.push(use_in_tensor_kernels[i]->name());
    }
  }
  return RET_OK;
}

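// Common export path, templated over the destination (file path or model buffer). Validates the
// destination and the model/quantization/format parameters, exports either the requested sub-graph
// or the full train/inference graph, applies fusion and drops training-only nodes for inference
// models, and finally writes the result, restoring train mode if it was active.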
template <typename DestType>
int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                              FormatType format, std::vector<std::string> out_put_tensor_name) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
    struct stat path_type;
    if (stat(destination.c_str(), &path_type) == RET_OK) {
      if (path_type.st_mode & S_IFDIR) {
        MS_LOG(ERROR) << "Destination must be a file path, not a directory";
        return RET_ERROR;
      }
    }
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
  } else {
    MS_LOG(ERROR) << "Unsupported destination.";
    return RET_ERROR;
  }
  MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
                     "Export model type parameter error");
  MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
                     "Export quant type parameter error");
  MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Currently only FT_FLATBUFFERS format is supported");

  bool orig_train_state = IsTrain();
  TrainExport texport(destination);
  int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");

  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    std::vector<kernel::KernelExec *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, const_fold_kernels_);
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
    status =
        texport.ExportNet(export_kernels, tensors_, const_output_tensors_, out_put_tensor_name, model_.get(), quant_type);
  } else {
    if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
        std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
          return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
        })) {
      status = texport.SaveModel(model_.get(), destination);
      if (orig_train_state) Train();
      return status;
    } else {
      if (quant_type == QT_NONE) {
        status = texport.ExportNet(
            (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_,
            (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type);
      } else {
        status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, {},
                                   (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
                                   model_.get(), quant_type);
      }
    }
  }
  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");

  if (model_type == MT_INFERENCE) {
    status = texport.TrainModelDrop();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
    status = texport.TrainModelFusion();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
  }
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    status = texport.SaveToFile();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "failed to save to " << destination;
      return status;
    }
  } else {
    status = texport.SaveToBuffer();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
  }
  if (orig_train_state) Train();
  return status;
}

int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
                         FormatType format, std::vector<std::string> out_put_tensor_name) {
  return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
}

int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
                         std::vector<std::string> out_put_tensor_name) {
  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
}

int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type,
                                                    FormatType format, bool enable_fp16,
                                                    const std::vector<std::string> &changeable_weights_name) {
  MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
  struct stat path_type;
  if (stat(file_name.c_str(), &path_type) == RET_OK) {
    if (path_type.st_mode & S_IFDIR) {
      MS_LOG(ERROR) << "Destination must be a file path, not a directory";
      return RET_ERROR;
    }
  }
  MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Currently only FT_FLATBUFFERS format is supported");
1322   MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR,
1323                      "Currently, can only export inference-model's weights.");
1324 
1325   TrainExport texport(file_name);
1326   auto status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
1327   TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
1328   // Find and prepare a list of kernels which are const folded
1329   MS_CHECK_TRUE_MSG(FindConstFoldedKernels() == RET_OK, RET_ERROR, "FindConstFoldedKernels failed.");
1330   status = texport.ExportNet(const_fold_kernels_, tensors_, const_output_tensors_, eval_output_tensor_names_,
1331                              model_.get(), QT_DEFAULT);
1332   TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
1333   status = texport.TrainModelDrop();
1334   TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
1335   status = texport.TrainModelFusion();
1336   TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
1337   status = texport.SaveWeightsToFile(enable_fp16, changeable_weights_name);
1338   if (status != RET_OK) {
1339     MS_LOG(ERROR) << "failed to save to " << file_name;
1340     return status;
1341   }
1342   return RET_OK;
1343 }
1344 
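// Usage sketch (illustrative only): exporting just the weights of the
// inference graph for the micro/code-generation workflow. The file name and
// the changeable weight name are assumptions for the example; the call must
// use MT_INFERENCE and FT_FLATBUFFERS, as enforced above.
//
//   std::vector<std::string> changeable = {"fc.weight"};  // hypothetical tensor name
//   int ret = session->ExportWeightsCollaborateWithMicro("net_weights.bin", mindspore::lite::MT_INFERENCE,
//                                                        mindspore::lite::FT_FLATBUFFERS,
//                                                        /*enable_fp16=*/false, changeable);
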
std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
  std::vector<lite::Tensor *> features;
  for (auto cur_tensor : this->tensors_) {
    if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
      features.push_back(cur_tensor);
    }
  }
  return features;
}

std::vector<lite::Tensor *> TrainSession::GetTrainableParams() const { return trainable_parameters_; }

int TrainSession::UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) {
  for (auto feature : features_map) {
    bool find = false;
    for (auto tensor : tensors_) {
      if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
        continue;
      }
      if (feature->tensor_name() != tensor->tensor_name()) {
        continue;
      }
      if (feature->Size() != tensor->Size()) {
        MS_LOG(ERROR) << "feature name: " << feature->tensor_name() << ", size mismatch: old is " << tensor->Size()
                      << ", new is " << feature->Size();
        return RET_ERROR;
      }
      find = true;
      memcpy(tensor->data(), feature->data(), tensor->Size());
    }
    if (!find) {
      MS_LOG(ERROR) << "cannot find feature: " << feature->tensor_name() << ", update failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

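// Usage sketch (illustrative only): GetFeatureMaps() returns every constant
// float32 tensor of the session, and UpdateFeatureMaps() copies new contents
// back into tensors with matching names and byte sizes. A typical round trip,
// e.g. aggregating weights outside the session, might look as follows;
// `Aggregate` is a hypothetical user-side function that returns tensors with
// the same names and sizes.
//
//   std::vector<mindspore::lite::Tensor *> features = session->GetFeatureMaps();
//   std::vector<mindspore::lite::Tensor *> merged = Aggregate(features);
//   if (session->UpdateFeatureMaps(merged) != RET_OK) {
//     // handle the error: a name was not found or a size did not match
//   }
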
std::set<schema::PrimitiveType> inPlaceSupportedKernels = {
  mindspore::schema::PrimitiveType_Activation, mindspore::schema::PrimitiveType_ActivationGrad,
  mindspore::schema::PrimitiveType_Reshape,    mindspore::schema::PrimitiveType_DivFusion,
  mindspore::schema::PrimitiveType_AddFusion,  mindspore::schema::PrimitiveType_SubFusion,
  mindspore::schema::PrimitiveType_RealDiv,    mindspore::schema::PrimitiveType_Select,
  mindspore::schema::PrimitiveType_BiasAdd,    mindspore::schema::PrimitiveType_BiasAddGrad,
  mindspore::schema::PrimitiveType_Sqrt,       mindspore::schema::PrimitiveType_Abs,
  mindspore::schema::PrimitiveType_Cos,        mindspore::schema::PrimitiveType_Log,
  mindspore::schema::PrimitiveType_Square,     mindspore::schema::PrimitiveType_Rsqrt,
  mindspore::schema::PrimitiveType_Sin,        mindspore::schema::PrimitiveType_LogicalNot,
  mindspore::schema::PrimitiveType_LogicalAnd, mindspore::schema::PrimitiveType_LogicalOr,
  mindspore::schema::PrimitiveType_Floor,      mindspore::schema::PrimitiveType_Ceil,
  mindspore::schema::PrimitiveType_Round,      mindspore::schema::PrimitiveType_Neg,
  mindspore::schema::PrimitiveType_Reciprocal, mindspore::schema::PrimitiveType_Erf,
  mindspore::schema::PrimitiveType_Maximum,    mindspore::schema::PrimitiveType_Minimum,
  mindspore::schema::PrimitiveType_FloorDiv,   mindspore::schema::PrimitiveType_FloorMod,
  mindspore::schema::PrimitiveType_Eltwise,    mindspore::schema::PrimitiveType_SquaredDifference,
  mindspore::schema::PrimitiveType_ExpandDims, mindspore::schema::PrimitiveType_Cast,
  mindspore::schema::PrimitiveType_Flatten,    mindspore::schema::PrimitiveType_FlattenGrad,
  mindspore::schema::PrimitiveType_Squeeze,    mindspore::schema::PrimitiveType_Unsqueeze};

bool TrainSession::IsInPlaceKernel(kernel::KernelExec *kernel) {
  if (inPlaceSupportedKernels.find(kernel->type()) != inPlaceSupportedKernels.end() &&
      !(kernel->type() == mindspore::schema::PrimitiveType_Activation &&
        kernel->op_parameter()->type_ == schema::ActivationType_SIGMOID)) {
    return true;
  }
  return false;
}

bool TrainSession::IsInPlaceTensor(kernel::KernelExec *kernel, uint32_t idx,
                                   const std::unordered_map<lite::Tensor *, int> &ref_count, uint32_t *input_idx) {
  if (IsInPlaceKernel(kernel)) {
    auto out_tensor = kernel->out_tensors().at(idx);
    for (size_t i = 0; i < kernel->in_tensors().size(); i++) {
      auto tensor = kernel->in_tensors().at(i);
      if ((tensor->category() == lite::Category::VAR) && (ref_count.find(tensor) != ref_count.end()) &&
          (tensor->init_ref_count() == 1 || (tensor->init_ref_count() > 1 && ref_count.at(tensor) == 1)) &&
          (out_tensor->Size() == tensor->Size())) {
        *input_idx = static_cast<uint32_t>(i);
        return true;
      }
    }
  }
  return false;
}

size_t TrainSession::GetInplaceTensorOffset(kernel::KernelExec *kernel,
                                            const std::unordered_map<lite::Tensor *, size_t> &offset_map,
                                            std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx) {
  auto tensor = kernel->in_tensors().at(input_idx);
  ref_count->at(tensor) = ref_count->at(tensor) + 1;
  return offset_map.at(tensor);
}

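// Sketch of how the three in-place helpers above are meant to be combined by
// an offset-based allocator (simplified and illustrative; the real allocation
// loop lives elsewhere in this session): for each kernel output, if an input
// of a supported element-wise/shape kernel is a VAR tensor with no remaining
// consumers and the same byte size, the output reuses that input's offset
// instead of getting a fresh allocation. `allocator` and `offset_map` are
// hypothetical names for this example.
//
//   for (uint32_t out_idx = 0; out_idx < kernel->out_tensors().size(); ++out_idx) {
//     uint32_t in_idx = 0;
//     size_t offset;
//     if (IsInPlaceTensor(kernel, out_idx, ref_count, &in_idx)) {
//       offset = GetInplaceTensorOffset(kernel, offset_map, &ref_count, in_idx);  // share the input's memory
//     } else {
//       offset = allocator.Malloc(kernel->out_tensors().at(out_idx)->Size());     // hypothetical allocator call
//     }
//     offset_map[kernel->out_tensors().at(out_idx)] = offset;
//   }
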
lite::Tensor *TrainSession::FindObfTensor() {
  for (auto node : model_->graph_.all_nodes_) {
    if (node->name_.find(kObfNodeName) != std::string::npos) {
      auto idx = node->input_indices_[kDataIndex];
      return tensors_[idx];
    }
  }
  return nullptr;
}

int TrainSession::ChangeObfWeight(std::string tensor_name, float obf_ratio) {
  float data[1] = {obf_ratio};
  auto new_tensor = lite::Tensor::CreateTensor(tensor_name, TypeId::kNumberTypeFloat32, {1, 1}, data, kFloatSize);
  std::vector<lite::Tensor *> modify_tensors;
  if (new_tensor == nullptr) {
    MS_LOG(ERROR) << "Create tensor failed";
    return RET_ERROR;
  }
  modify_tensors.emplace_back(new_tensor);
  auto ret = this->UpdateWeights(modify_tensors);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "UpdateWeights failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

float TrainSession::ModelRecoverObfuscate(bool change_weight) {
  float true_obf_ratio = 1.0;
  auto tensor = FindObfTensor();
  if (tensor != nullptr) {
    std::string tensor_name = tensor->tensor_name();
    true_obf_ratio = *(reinterpret_cast<float *>(tensor->data()));
    if (!change_weight) {
      return true_obf_ratio;
    }
    float init_obf_ratio = 1.0;
    ChangeObfWeight(tensor_name, init_obf_ratio);
  }
  return true_obf_ratio;
}

int TrainSession::ModelDeObfuscate(float obf_ratio) {
  if (!FloatCompare(obf_ratio, 0.0)) {
    auto *tensor = FindObfTensor();
    if (tensor != nullptr) {
      std::string tensor_name = tensor->tensor_name();
      return ChangeObfWeight(tensor_name, obf_ratio);
    }
    MS_LOG(ERROR) << "Obfuscation tensor is null";
    return RET_ERROR;
  }
  MS_LOG(ERROR) << "Obfuscation ratio is 0";
  return RET_ERROR;
}
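
// Sketch of the obfuscation helpers above (illustrative only): the obfuscation
// ratio lives in the constant input of the node whose name contains
// kObfNodeName. ModelRecoverObfuscate() reads that ratio and, when requested,
// resets the stored weight to 1.0; ModelDeObfuscate() writes a non-zero ratio
// back into the same tensor. A minimal read-then-restore flow, assuming
// `session` is a compiled TrainSession pointer:
//
//   float obf_ratio = session->ModelRecoverObfuscate(/*change_weight=*/true);  // ratio read, weight reset to 1.0
//   // ... work with the temporarily un-scaled weight ...
//   if (session->ModelDeObfuscate(obf_ratio) != RET_OK) {
//     // restoring the obfuscation ratio failed
//   }
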
}  // namespace lite
}  // namespace mindspore