1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/train_session.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include <queue>
26 #include <map>
27 #include "include/errorcode.h"
28 #include "src/executor.h"
29 #include "src/lite_model.h"
30 #include "src/lite_kernel_util.h"
31 #include "src/sub_graph_kernel.h"
32 #include "src/tensor.h"
33 #include "src/kernel_registry.h"
34 #include "src/common/prim_util.h"
35 #include "src/common/tensor_util.h"
36 #include "src/common/utils.h"
37 #include "src/runtime/kernel/arm/fp32_grad/convolution.h"
38 #include "src/runtime/kernel/arm/fp32/batchnorm_fp32.h"
39 #include "src/train/loss_kernel.h"
40 #include "src/train/optimizer_kernel.h"
41 #include "src/train/train_utils.h"
42 #include "src/train/train_export.h"
43 #include "src/train/opt_allocator.h"
44 #include "src/train/static_allocator.h"
45 #include "src/train/train_populate_parameter.h"
46 #include "src/train/train_populate_parameter_v0.h"
47 
48 namespace mindspore {
49 namespace lite {
50 const char *kGradName = "Gradients";
51 const char *kOptimizerName = "optimizer";
52 
53 TrainSession::TrainSession() {
54   is_train_session_ = true;
55   InitCallBack();
56 }
57 
58 int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
59   if (train_cfg != nullptr) {
60     if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
61       MS_LOG(ERROR) << "illegal loss scale configuration";
62       return RET_NULL_PTR;
63     }
64     cfg_ = *train_cfg;
65   }
66   if (context == nullptr) {
67     MS_LOG(ERROR) << "context cannot be nullptr";
68     return RET_NULL_PTR;
69   }
70   allocator_ = context->allocator;
71   return lite::LiteSession::Init(context);
72 }
73 
74 std::vector<CreatorOp> TrainSession::ReplaceOps() {
75   const std::vector<CreatorOp> replace = {
76     // currently no ops are hijacked by TrainSession
77   };
78   mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
79   std::vector<CreatorOp> results;
80   for (auto v : replace) {
81     const CreatorOp cl = make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v)));
82     results.push_back(cl);
83     reg->RegKernel(std::get<0>(v), std::get<1>(v));
84   }
85   return results;
86 }
87 
88 void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) {
89   mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
90   for (auto v : restore) {
91     reg->RegKernel(std::get<0>(v), std::get<1>(v));
92   }
93 }
94 
95 int TrainSession::AllocWorkSpace() {
96   size_t workspace_size = 0;
97   for (auto kernel : this->train_kernels_) {
98     if (workspace_size < static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size()) {
99       workspace_size = static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size();
100     }
101   }
102   workspace_ = malloc(workspace_size);
103   if (workspace_ == nullptr) {
104     MS_LOG(ERROR) << "cannot allocate " << workspace_size << " for workspace";
105     return RET_ERROR;
106   }
107   for (auto kernel : this->train_kernels_) {
108     static_cast<kernel::InnerKernel *>(kernel->kernel())->set_workspace(workspace_);
109   }
110   return RET_OK;
111 }
112 
113 void TrainSession::FreeWorkSpace() {
114   if (workspace_ != nullptr) {
115     free(workspace_);
116     workspace_ = nullptr;
117   }
118   for (auto kernel : this->train_kernels_) {
119     static_cast<kernel::InnerKernel *>(kernel->kernel())->FreeWorkspace();
120   }
121 }
122 
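// Scheduler callback used during graph compilation to decide, per node, whether it runs in
// fp16: fp16 is considered only when CPU fp16 is enabled; in raw mix-precision mode the choice
// follows the node's first output tensor type, otherwise Cast, optimizer, loss and (optionally)
// batchnorm nodes are kept in fp32.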
123 int TrainSession::InitCallBack() {
124   sched_mix_precision_callback_ = [&](const Model::Node *node) {
125     if (!context_->IsCpuFloat16Enabled()) {
126       return false;
127     }
128     if (cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
129       auto out_tensor_indexs = node->output_indices_;
130       if (out_tensor_indexs.empty()) {
131         MS_LOG(DEBUG) << "Debug: " << node->name_ << " fp32";
132         return false;
133       }
134       auto is_fp16 = model_->all_tensors_.at(out_tensor_indexs[0])->dataType() == kNumberTypeFloat16;
135       MS_LOG(DEBUG) << "Debug: " << node->name_ << ((is_fp16) ? " fp16" : " fp32");
136       return is_fp16;
137     }
138     auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
139     if (node_type == schema::PrimitiveType_Cast) {
140       return false;
141     }
142     auto in_size = node->input_indices_.size();
143     bool force_fp16 = false;
144     for (std::size_t k = 0; k < in_size; k++) {
145       schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
146       if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
147         force_fp16 = true;
148         break;
149       }
150     }
151     const auto &node_name = node->name_;
152     bool is_fp16 = true;
153     if (!force_fp16) {
154       // optimizer runs in fp32
155       if (node_name.find(kOptimizerName) != std::string::npos) {
156         is_fp16 = false;
157       }
158       // loss function runs in fp32
159       if ((node_name.find(get_loss_name()) != std::string::npos)) {
160         is_fp16 = false;
161       }
162       // run bn according to user configuration
163       if ((cfg_.mix_precision_cfg_.keep_batchnorm_fp32_) &&
164           (node_type == schema::PrimitiveType_FusedBatchNorm || node_type == schema::PrimitiveType_BatchNorm ||
165            node_type == schema::PrimitiveType_BatchNormGrad)) {
166         is_fp16 = false;
167       }
168     }
169     MS_LOG(DEBUG) << "Debug: " << node_name << ((is_fp16) ? " fp16" : " fp32");
170     return is_fp16;
171   };
172   return RET_OK;
173 }
174 
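// With a static allocator, plan all kernel output tensors into one shared buffer: offsets are
// assigned as tensors are produced and released when their reference count drops to zero, so
// tensors with disjoint lifetimes reuse the same memory. The backing buffer is allocated once
// and handed to the allocator through SetContex().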
175 int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) {
176   if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
177   OptAllocator allocator;
178   std::unordered_map<lite::Tensor *, int> ref_count;
179   std::unordered_map<lite::Tensor *, size_t> offset_map;
180   for (auto kernel : kernels) {
181     for (auto tensor : kernel->out_tensors()) {
182       size_t size = tensor->Size();
183       size_t offset = allocator.Malloc(size);
184       offset_map[tensor] = offset;
185       ref_count[tensor] = tensor->init_ref_count();
186     }
187     for (auto tensor : kernel->in_tensors()) {
188       if (tensor->category() == lite::Tensor::VAR) {
189         int count = ref_count[tensor] - 1;
190         ref_count[tensor] = count;
191         if (count == 0) {
192           allocator.Free(offset_map[tensor]);
193         }
194       }
195     }
196   }
197   // Set Tensor data
198   if (tensors_data_ == nullptr) {
199     auto size = allocator.total_size();
200     auto buf = malloc(size);
201     if (buf == nullptr) {
202       MS_LOG(ERROR) << "cannot allocate buffer of size " << size;
203       return RET_ERROR;
204     }
205     StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
206     alloc->SetContex(buf, size);
207     tensors_data_ = buf;
208   }
209   for (auto kernel : train_kernels_) {
210     for (auto tensor : kernel->out_tensors()) {
211       auto it = offset_map.find(tensor);
212       if (it != offset_map.end()) {
213         tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<char *>(tensors_data_) + it->second));
214       }
215     }
216   }
217   return RET_OK;
218 }
219 
220 int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
221 
222 int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
223   model_ = model;
224   auto restore = ReplaceOps();
225   sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
226   if (sched_cb_ == nullptr) {
227     MS_LOG(ERROR) << "Failed to create SchedulerCb node";
228     return RET_ERROR;
229   }
230 
231 #ifdef ENABLE_V0
232   if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_V0) {
233     kernel::PopulateTrainV0Parameters();
234   }
235 #endif
236   if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_CUR) {
237     kernel::PopulateTrainParameters();
238   }
239 
240   auto ret = lite::LiteSession::CompileGraph(model_.get());
241   if (ret != RET_OK) {
242     MS_LOG(ERROR) << "failed to compile train model";
243     return RET_ERROR;
244   }
245   orig_output_node_map_ = output_node_map_;
246   orig_output_tensor_map_ = output_tensor_map_;
247   orig_output_tensor_names_ = output_tensor_names_;
248   for (auto inTensor : inputs_) inTensor->MutableData();
249   RestoreOps(restore);
250   CompileTrainKernels();      // Prepare a list of train kernels
251   CompileOptimizedKernels();  // Prepare a list of kernels which are optimized (weight update step)
252   CompileTrainOutputs();      // prepare outputs in train mode
253   CompileEvalOutputs();       // prepare outputs in eval mode
254   // Prepare a list of eval kernels
255   if (CompileInferenceKernels() != RET_OK) {
256     MS_LOG(ERROR) << "CompileInferenceKernels failed.";
257     return RET_ERROR;
258   }
259   ret = AllocWorkSpace();
260   if (ret != RET_OK) {
261     MS_LOG(ERROR) << "failed to allocate space";
262     return RET_ERROR;
263   }
264   ret = AllocTensors(train_kernels_);
265   if (ret != RET_OK) {
266     MS_LOG(ERROR) << "failed to allocate space";
267     return RET_ERROR;
268   }
269   return RET_OK;
270 }
271 
272 TrainSession::~TrainSession() {
273   FreeWorkSpace();
274   if (tensors_data_ != nullptr) {
275     free(tensors_data_);
276     tensors_data_ = nullptr;
277   }
278 }
279 
280 int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
281                               const std::vector<kernel::LiteKernel *> &run_kernels) {
282   for (auto *kernel : run_kernels) {
283     MS_ASSERT(kernel != nullptr);
284     auto ret = kernel->Execute(before, after);
285     if (RET_OK != ret) {
286       MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
287       return ret;
288     }
289   }
290   return RET_OK;
291 }
292 
293 void TrainSession::RestoreTensorData() {
294   for (auto &restored_origin_tensor : restored_origin_tensors_) {
295     auto *origin_tensor = restored_origin_tensor.first;
296     auto *restored_tensor = restored_origin_tensor.second;
297     MS_ASSERT(origin_tensor != nullptr);
298     MS_ASSERT(restored_tensor != nullptr);
299 
300     bool own_data = restored_tensor->own_data();
301     if (origin_tensor->data() == nullptr) {
302       restored_tensor->FreeData();
303     } else {
304       origin_tensor->FreeData();
305     }
306     origin_tensor->set_data_type(restored_tensor->data_type());
307     origin_tensor->set_data(restored_tensor->data());
308     origin_tensor->set_own_data(own_data);
309   }
310 }
311 
312 void TrainSession::FreeRestoreTensors() {
313   for (auto &restored_origin_tensor : restored_origin_tensors_) {
314     auto *restored_tensor = restored_origin_tensor.second;
315     restored_tensor->set_data(nullptr);
316     delete (restored_tensor);
317   }
318   restored_origin_tensors_.clear();
319 }
320 
321 bool TrainSession::IsLossTensor(Tensor *tensor) {
322   MS_ASSERT(tensor != nullptr);
323   auto t_n = tensor->tensor_name();
324   return (t_n.find(get_loss_name()) != std::string::npos);
325 }
326 
327 bool TrainSession::AllInputsNeedScale(kernel::LiteKernel *kernel) {
328   auto type = kernel->type();
329   bool is_scale = false;
330   switch (type) {
331     case schema::PrimitiveType_AbsGrad:
332     case schema::PrimitiveType_AddFusion:
333     case schema::PrimitiveType_SubFusion:
334     case schema::PrimitiveType_AddN:
335       for (auto &tensor : kernel->in_tensors()) {
336         is_scale = is_scale || tensor->IsScale();
337       }
338       return (is_scale);
339     default:
340       return false;
341   }
342   return false;
343 }
344 
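// Pre-step for mixed-precision execution: scale tensors on the loss path (or all inputs for
// ops whose inputs must all be scaled), then cast inputs to the kernel's data type, keeping the
// originals in restored_origin_tensors_ so they can be put back after the kernel runs.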
345 int TrainSession::MixPrecisionPreProcess(kernel::LiteKernel *kernel, float scale) {
346   auto kernel_type = kernel->desc().data_type;
347   auto all_scale = AllInputsNeedScale(kernel);
348 
349   for (auto &tensor : kernel->in_tensors()) {
350     if ((tensor->IsScale() == false) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
351       ScaleTensor(tensor, scale);
352     }
353     // adjust tensor data type
354     if (tensor->data_type() != kernel_type) {
355       auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16());
356       if (restore_tensor != nullptr) {
357         restored_origin_tensors_[tensor] = restore_tensor;
358       }
359     }
360   }
361   return RET_OK;
362 }
363 
364 int TrainSession::MixPrecisionPostProcess(kernel::LiteKernel *kernel) {
365   RestoreTensorData();
366   FreeRestoreTensors();
367 
368   float scale = 1.0f;
369   auto all_scale = AllInputsNeedScale(kernel);
370   for (auto &tensor : kernel->in_tensors()) {
371     if (tensor->IsScale()) {
372       scale *= tensor->get_scale();
373       if (all_scale) {
374         break;
375       }
376     }
377   }
378   for (auto &tensor : kernel->out_tensors()) {
379     tensor->set_scale(scale);
380   }
381 
382   for (auto &tensor : kernel->in_tensors()) {
383     if ((tensor->IsScale() == true) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
384       ScaleTensor(tensor, 1.0f / scale);
385     }
386   }
387   return RET_OK;
388 }
389 
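// Mixed-precision run loop with loss scaling: on overflow (RET_OUT_OF_TENSOR_RANGE) the scale
// is halved (never below 1.0) when dynamic loss scaling is enabled and the iteration is dropped;
// after num_of_not_nan_iter_th_ clean iterations the scale is doubled, capped at 65536. In eval
// mode any fp16 outputs are cast back to fp32 before returning.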
390 int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
391                                           const std::vector<kernel::LiteKernel *> &run_kernels) {
392   float scale = cfg_.mix_precision_cfg_.loss_scale_;
393   for (auto *kernel : run_kernels) {
394     MS_ASSERT(kernel != nullptr);
395     MixPrecisionPreProcess(kernel, scale);
396     auto ret = kernel->Execute(before, after);
397     if (RET_OK != ret) {
398       MixPrecisionPostProcess(kernel);
399       // decrease loss scale in case of nan or inf
400       if (ret == RET_OUT_OF_TENSOR_RANGE) {
401         bool is_dynamic_scale = cfg_.mix_precision_cfg_.dynamic_loss_scale_;
402         cfg_.mix_precision_cfg_.loss_scale_ = std::max(((is_dynamic_scale) ? (scale / 2.f) : scale), 1.0f);
403         num_of_not_nan_iter_ = 0;
404         return RET_OK;
405       }
406       MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
407       return ret;
408     }
409     MixPrecisionPostProcess(kernel);
410   }
411   // increase the dynamic loss scale once the not-NaN iteration count passes the threshold
412   if (cfg_.mix_precision_cfg_.dynamic_loss_scale_) {
413     num_of_not_nan_iter_++;
414     if (num_of_not_nan_iter_ >= cfg_.mix_precision_cfg_.num_of_not_nan_iter_th_) {
415       cfg_.mix_precision_cfg_.loss_scale_ = std::min((cfg_.mix_precision_cfg_.loss_scale_ * 2.0f), 65536.0f);
416       num_of_not_nan_iter_ = 0;
417     }
418   }
419 
420   // cast output to FP32
421   if (train_mode_ == false) {
422     for (auto t : this->outputs_) {
423       if (t->data_type() == kNumberTypeFloat16) {
424         auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16());
425         delete restore;
426       }
427     }
428   }
429   return RET_OK;
430 }
431 
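// Runs the train or inference kernel list depending on the current mode, taking the
// mixed-precision path when CPU fp16 is enabled and raw mix precision is off. With virtual
// batching enabled, the optimizer step is applied only once every virtual_batch_multiplier_
// iterations.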
432 int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
433   // check inputs
434   auto ret = CheckTensorsInvalid(inputs_);
435   if (ret != RET_OK) {
436     MS_LOG(ERROR) << "CheckInputs failed";
437     return ret;
438   }
439 
440   // build out tensor
441   this->outputs_.clear();
442   for (auto &ms_tensors : output_node_map_) {
443     for (auto &ms_tensor : ms_tensors.second) {
444       auto lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
445       this->outputs_.push_back(lite_tensor);
446     }
447   }
448 
449   if (this->context_ == nullptr) {
450     MS_LOG(ERROR) << "context is null";
451     return lite::RET_NULL_PTR;
452   }
453   auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
454   if (context_->IsCpuFloat16Enabled() && !cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
455     ret = MixPrecisionExecKernels(before, after, run_kernels);
456   } else {
457     ret = ExecKernels(before, after, run_kernels);
458   }
459   if (ret != RET_OK) {
460     MS_LOG(ERROR) << "failed to run model kernels";
461     return ret;
462   }
463 
464   if (train_mode_ && virtual_batch_multiplier_) {
465     virtual_batch_idx_++;
466     if (virtual_batch_idx_ >= virtual_batch_multiplier_) {
467       virtual_batch_idx_ = 0;
468       ret = OptimizerStep();
469       if (ret != RET_OK) {
470         MS_LOG(ERROR) << "failed to optimize model weights";
471         return ret;
472       }
473     }
474   }
475   return RET_OK;
476 }
477 
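// Switches every kernel to train mode and publishes the train-mode outputs; eval outputs get an
// extra init reference so predictions can still be read while training.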
478 int TrainSession::Train() {
479   // shift kernels to train mode
480   train_mode_ = true;
481   virtual_batch_idx_ = 0;
482   for (auto &kernel : this->train_kernels_) {
483     MS_ASSERT(kernel != nullptr);
484     auto ret = kernel->Train();
485     if (ret != RET_OK) {
486       MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
487       return RET_ERROR;
488     }
489   }
490   // set train outputs
491   output_node_map_ = train_output_node_map_;
492   output_tensor_map_ = train_output_tensor_map_;
493   output_tensor_names_ = train_output_tensor_names_;
494   kernel::LiteKernelUtil::InitTensorInitRefCount(train_kernels_);
495   for (auto &ms_tensors : eval_output_node_map_) {  // allow looking at predictions during training as well
496     for (auto &ms_tensor : ms_tensors.second) {
497       lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
498       lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
499     }
500   }
501   // allocate tensors
502   auto ret = AllocTensors(train_kernels_);
503   if (ret != RET_OK) {
504     MS_LOG(ERROR) << "failed to allocate tensor space";
505     return RET_ERROR;
506   }
507   return RET_OK;
508 }
509 
510 int TrainSession::Eval() {
511   // shift kernels to eval mode
512   train_mode_ = false;
513   virtual_batch_idx_ = 0;
514   for (auto &kernel : this->train_kernels_) {
515     MS_ASSERT(kernel != nullptr);
516     auto ret = kernel->Eval();
517     if (ret != RET_OK) {
518       MS_LOG(ERROR) << kernel->name() << " failed to set eval mode";
519       return RET_ERROR;
520     }
521   }
522   // set eval outputs
523   output_node_map_ = eval_output_node_map_;
524   output_tensor_map_ = eval_output_tensor_map_;
525   output_tensor_names_ = eval_output_tensor_names_;
526   kernel::LiteKernelUtil::InitTensorInitRefCount(inference_kernels_);
527   for (auto &ms_tensors : eval_output_node_map_) {
528     for (auto &ms_tensor : ms_tensors.second) {
529       lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
530       lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
531     }
532   }
533   auto ret = AllocTensors(inference_kernels_);
534   if (ret != RET_OK) {
535     MS_LOG(ERROR) << "failed to allocate space";
536     return RET_ERROR;
537   }
538   return RET_OK;
539 }
540 
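// Eval outputs are the outputs of non-loss, non-gradient kernels that feed a loss kernel, i.e.
// the network predictions; if none are found, the original model outputs are kept.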
541 void TrainSession::CompileEvalOutputs() {
542   eval_output_node_map_.clear();
543   eval_output_tensor_map_.clear();
544   eval_output_tensor_names_.clear();
545   for (auto kernel : this->train_kernels_) {
546     if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) {
547       for (auto in_kernel : kernel->in_kernels()) {
548         if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
549         // insert if not already in
550         if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) {
551           auto *ms_tensor = in_kernel->out_tensors().at(0);
552           if (ms_tensor != nullptr) {
553             ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1);
554             eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor);
555             auto index = TSFindTensor(tensors_, ms_tensor);
556             if (index != tensors_.size()) {
557               eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
558               if (!ms_tensor->tensor_name().empty()) {
559                 eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
560               } else {
561                 eval_output_tensor_names_.emplace_back(std::to_string(index));
562               }
563             }
564           }
565         }
566       }
567     }
568   }
569   if (eval_output_node_map_.size() == 0) eval_output_node_map_ = orig_output_node_map_;
570   if (eval_output_tensor_map_.size() == 0) eval_output_tensor_map_ = orig_output_tensor_map_;
571   if (eval_output_tensor_names_.size() == 0) eval_output_tensor_names_ = orig_output_tensor_names_;
572 }
573 
574 void TrainSession::CompileTrainOutputs() {
575   train_output_node_map_.clear();
576   train_output_tensor_map_.clear();
577   train_output_tensor_names_.clear();
578   for (auto kernel : this->train_kernels_) {
579     if (orig_output_node_map_.find(kernel->name()) == orig_output_node_map_.end()) continue;
580     // Mask out optimizer out tensors
581     if (IsMaskOutput(kernel)) continue;
582     // insert if not already in
583     if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) {
584       auto *ms_tensor = kernel->out_tensors().at(0);
585       if (ms_tensor != nullptr) {
586         train_output_node_map_[kernel->name()].emplace_back(ms_tensor);
587         auto index = TSFindTensor(tensors_, ms_tensor);
588         if (index != tensors_.size()) {
589           train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
590           if (!ms_tensor->tensor_name().empty()) {
591             train_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
592           } else {
593             train_output_tensor_names_.emplace_back(std::to_string(index));
594           }
595         }
596       }
597     }
598   }
599   if (train_output_node_map_.size() == 0) train_output_node_map_ = orig_output_node_map_;
600   if (train_output_tensor_map_.size() == 0) train_output_tensor_map_ = orig_output_tensor_map_;
601   if (train_output_tensor_names_.size() == 0) train_output_tensor_names_ = orig_output_tensor_names_;
602 }
603 
604 void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) {
605   MS_ASSERT(kernel != nullptr);
606   MS_ASSERT(v != nullptr);
607   if (std::find(v->begin(), v->end(), kernel) == v->end()) {  // kernel is not already in vector
608     for (auto in_node : kernel->in_kernels()) {
609       BuildInferenceKernelsRecursive(in_node, v);
610     }
611     if (!IsLossKernel(kernel)) v->push_back(kernel);
612   }
613 }
614 
615 void TrainSession::CompileTrainKernels() {
616   train_kernels_.clear();
617   for (auto ori_kernel : kernels_) {
618     if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
619       train_kernels_.push_back(ori_kernel);
620     } else {
621       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
622       for (auto kernel : sub_graph->nodes()) {
623         train_kernels_.push_back(kernel);
624       }
625     }
626   }
627 }
628 
629 int TrainSession::CompileInferenceKernels() {
630   inference_kernels_.clear();
631   for (auto item : eval_output_node_map_) {
632     std::string kernel_name = item.first;
633     auto kernel = TSFindKernel(train_kernels_, kernel_name);
634     if (kernel == nullptr) {
635       MS_LOG(ERROR) << "kernel is nullptr";
636       return RET_ERROR;
637     }
638     BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
639   }
640   if (inference_kernels_.size() == 0) {
641     inference_kernels_ = this->train_kernels_;
642   }
643   return RET_OK;
644 }
645 
646 void TrainSession::CompileOptimizedKernels() {
647   std::vector<lite::Tensor *> out_tensor;
648   for (auto kernel : this->train_kernels_) {
649     if (IsOptimizer(kernel)) {
650       std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
651       if (cfg_.accumulate_gradients_) {
652         auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
653         optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
654       }
655     }
656   }
657 
658   for (auto kernel : this->train_kernels_) {
659     if (!IsOptimizer(kernel)) {
660       for (auto it : kernel->in_tensors()) {
661         if (std::find(out_tensor.begin(), out_tensor.end(), it) != out_tensor.end()) {
662           kernel->SetTrainable(true);
663           break;
664         }
665       }
666     }
667   }
668 }
669 
670 int TrainSession::SetLearningRate(float learning_rate) {
671   if (learning_rate < 0.0f) {
672     MS_LOG(ERROR) << "learning rate should not be negative";
673     return RET_ERROR;
674   }
675   for (auto kernel : this->train_kernels_) {
676     if (IsOptimizer(kernel)) {
677       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
678       auto ret = optimizer->SetLearningRate(learning_rate);
679       if (ret != RET_OK) {
680         MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
681         return RET_ERROR;
682       }
683     }
684   }
685   return RET_OK;
686 }
687 
688 float TrainSession::GetLearningRate() {
689   for (auto kernel : this->train_kernels_) {
690     if (IsOptimizer(kernel)) {
691       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
692       return optimizer->GetLearningRate();
693     }
694   }
695   return 0.0;
696 }
697 
698 std::vector<tensor::MSTensor *> TrainSession::GetOptimizerParams() const {
699   std::vector<tensor::MSTensor *> params;
700   for (auto kernel : this->train_kernels_) {
701     if (IsOptimizer(kernel)) {
702       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
703       auto kernelParams = optimizer->GetOptimizerParams();
704       for (size_t ix = 0; ix < kernelParams.size(); ix++) {
705         auto kernelParam = kernelParams[ix];
706         auto name = kernelParam->tensor_name();
707         bool found = false;
708         for (size_t iy = 0; iy < params.size(); iy++) {
709           if (params[iy]->tensor_name() == name) {
710             found = true;
711             break;
712           }
713         }
714         if (!found) {
715           params.push_back(kernelParam);
716         }
717       }
718     }
719   }
720   return params;
721 }
722 
723 int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) {
724   for (size_t ix = 0; ix < params.size(); ix++) {
725     auto param = params[ix];
726     if (param == nullptr) {
727       MS_LOG(ERROR) << "Param tensor at index " << ix << " is null.";
728       return RET_ERROR;
729     }
730     bool found = false;
731     for (auto kernel : this->train_kernels_) {
732       if (IsOptimizer(kernel)) {
733         auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
734         found = optimizer->SetOptimizerParams(param);
735         if (found) break;
736       }
737     }
738     if (!found) {
739       MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elements and type "
740                     << param->data_type() << " is not a valid params tensor";
741       return RET_ERROR;
742     }
743   }
744   return RET_OK;
745 }
746 
747 std::vector<tensor::MSTensor *> TrainSession::GetGradients() const {
748   std::vector<tensor::MSTensor *> gradients;
749   for (auto kernel : this->train_kernels_) {
750     if (IsOptimizer(kernel)) {
751       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
752       auto kernelGradint = optimizer->GetGradients();
753       if (kernelGradint != nullptr) {
754         gradients.push_back(kernelGradint);
755       }
756     }
757   }
758   return gradients;
759 }
760 
761 int TrainSession::ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) {
762   auto current_gradients = GetGradients();
763   if (current_gradients.size() != gradients.size()) {
764     MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
765                   << current_gradients.size();
766     return RET_ERROR;
767   }
768   for (size_t ix = 0; ix < gradients.size(); ix++) {
769     auto gradient = gradients[ix];
770     if (gradient == nullptr) {
771       MS_LOG(ERROR) << "gradient tensor at index " << ix << " is null.";
772       return RET_ERROR;
773     }
774     bool found = false;
775     for (size_t iy = 0; iy < current_gradients.size(); iy++) {
776       auto current_gradient = current_gradients[iy];
777       if (current_gradient->tensor_name() == gradient->tensor_name()) {
778         found = true;
779         if (current_gradient->Size() == gradient->Size()) {
780           std::copy(static_cast<char *>(gradient->data()), static_cast<char *>(gradient->data()) + gradient->Size(),
781                     static_cast<char *>(current_gradient->MutableData()));
782         } else {
783           MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
784                         << " instead of " << current_gradient->Size();
785           return RET_ERROR;
786         }
787         break;
788       }
789     }
790     if (!found) {
791       MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
792       return RET_ERROR;
793     }
794   }
795   for (auto kernel : this->train_kernels_) {
796     if (IsOptimizer(kernel)) {
797       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
798       optimizer->set_grad_sum_valid();
799       auto ret = optimizer->OptimizerStep();
800       if (ret != RET_OK) {
801         MS_LOG(ERROR) << "failed to optimize model weights";
802         return ret;
803       }
804     }
805   }
806   for (size_t ix = 0; ix < current_gradients.size(); ix++) {
807     delete current_gradients[ix];
808   }
809   return RET_OK;
810 }
811 
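// Applies a virtual batch configuration: a multiplier <= 1 restores NORMAL weight updates,
// otherwise optimizers switch to VIRTUAL_BATCH mode; when lr or momentum is passed as a negative
// value, the current learning rate / batchnorm momentum is divided by the multiplier instead.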
812 int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
813   auto mod =
814     (virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
815   virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
816   virtual_batch_idx_ = 0;
817 
818   for (auto kernel : this->train_kernels_) {
819     if (IsOptimizer(kernel)) {
820       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
821       if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
822           optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
823         MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
824         return RET_ERROR;
825       }
826       auto ret = optimizer->SetOptimizerMode(mod);
827       if (ret != RET_OK) {
828         MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
829         return RET_ERROR;
830       }
831       if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
832         lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
833         ret = optimizer->SetLearningRate(lr);
834       } else {
835         ret = optimizer->RestoreDefaultLearningRate();
836       }
837       if (ret != RET_OK) {
838         MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
839         return RET_ERROR;
840       }
841     }
842 
843     if (IsBN(kernel) && kernel->IsTrainable()) {
844       auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
845       auto ret = RET_OK;
846       if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
847         momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
848         ret = batchnorm->set_momentum(momentum);
849       } else {
850         ret = batchnorm->RestoreDefaultMomentum();
851       }
852       if (ret != RET_OK) {
853         MS_LOG(ERROR) << kernel->name() << " failed to set momentum";
854         return RET_ERROR;
855       }
856     }
857   }
858   return RET_OK;
859 }
860 int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
861   int tmp = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
862   if (tmp != 0 && virtual_batch_multiplier_ != 0) {
863     AdminSetupVirtualBatch(0, lr, momentum);
864   }
865   return AdminSetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
866 }
867 
868 int TrainSession::OptimizerStep() {
869   for (auto kernel : this->train_kernels_) {
870     if (IsOptimizer(kernel)) {
871       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
872       auto ret = optimizer->OptimizerStep();
873       if (ret != RET_OK) {
874         MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
875         return RET_ERROR;
876       }
877     }
878   }
879   return RET_OK;
880 }
881 
882 bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
883   return (kernel->type() == schema::PrimitiveType_SoftmaxCrossEntropyWithLogits ||
884           kernel->type() == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits ||
885           kernel->type() == schema::PrimitiveType_SmoothL1Loss ||
886           kernel->type() == schema::PrimitiveType_SmoothL1LossGrad ||
887           kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
888           kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad) ||
889          kernel->name().find(cfg_.loss_name_) != std::string::npos;
890 }
891 
892 bool TrainSession::IsGradKernel(const kernel::LiteKernel *kernel) const {
893   return kernel->name().find(kGradName) != std::string::npos;
894 }
895 
896 bool TrainSession::IsOptimizer(kernel::LiteKernel *kernel) const {
897   return ((kernel->type() == schema::PrimitiveType_Adam) || (kernel->type() == schema::PrimitiveType_SGD) ||
898           (kernel->type() == schema::PrimitiveType_ApplyMomentum));
899 }
900 
901 bool TrainSession::IsMaskOutput(kernel::LiteKernel *kernel) const {
902   return (IsOptimizer(kernel) || (kernel->type() == schema::PrimitiveType_Assign));
903 }
904 
905 bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
906   return ((kernel->type() == schema::PrimitiveType_BatchNorm) ||
907           (kernel->type() == schema::PrimitiveType_FusedBatchNorm));
908 }
909 
910 int TrainSession::Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) {
911   FreeWorkSpace();
912   if (tensors_data_ != nullptr) {
913     free(tensors_data_);
914     tensors_data_ = nullptr;
915   }
916   auto ret = lite::LiteSession::Resize(inputs, dims);
917   if (ret != RET_OK) {
918     MS_LOG(ERROR) << "train resize input failed.";
919     return RET_ERROR;
920   }
921   ret = AllocWorkSpace();
922   if (ret != RET_OK) {
923     MS_LOG(ERROR) << "failed to allocate space";
924     return RET_ERROR;
925   }
926   ret = AllocTensors(train_kernels_);
927   if (ret != RET_OK) {
928     MS_LOG(ERROR) << "train alloc failed after resize.";
929     return RET_ERROR;
930   }
931   return RET_OK;
932 }
933 
934 int TrainSession::FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
935                                         const std::vector<lite::Tensor *> &kernel_in_tensors,
936                                         const std::vector<kernel::LiteKernel *> &inference_kernels) {
937   for (size_t i = 0; i < inference_kernels.size(); i++) {
938     for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
939       if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
940         use_in_tensor_kernels->push_back(inference_kernels[i]);
941       }
942     }
943   }
944   return RET_OK;
945 }
946 
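// Collects the kernels needed to produce the requested export output tensors: start from the
// kernels that own those tensors and walk backwards through input producers (found with
// FindUseInTensorKernel) until the whole export subgraph has been gathered.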
947 int TrainSession::FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
948                                     const std::vector<std::string> &export_output_tensor_names,
949                                     const std::vector<kernel::LiteKernel *> &inference_kernels) {
950   std::vector<std::string> all_kernel_name = {};
951   std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
952                  [](kernel::LiteKernel *kernel) { return kernel->name(); });
953   std::queue<std::string> need_kernel_names;
954   // Find the kernel name according to the tensor name
955   for (auto &kernel : inference_kernels) {
956     if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
957           return IsContain(export_output_tensor_names, out_tensor->tensor_name());
958         })) {
959       need_kernel_names.push(kernel->name());
960     }
961   }
962   if (need_kernel_names.size() == 0) {
963     MS_LOG(ERROR) << "can not find tensor";
964     return RET_ERROR;
965   }
966   // find all kernel
967   while (!need_kernel_names.empty()) {
968     auto kernel_name = need_kernel_names.front();
969     need_kernel_names.pop();
970     auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
971     if (it == all_kernel_name.end()) {
972       MS_LOG(ERROR) << "cannot find kernel name when exporting the trained model.";
973       return RET_ERROR;
974     }
975     auto kernel = inference_kernels[it - all_kernel_name.begin()];
976     if (!IsContain(*export_kernels, kernel)) {
977       export_kernels->push_back(kernel);
978     }
979     auto kernel_in_tensors = kernel->in_tensors();
980     std::vector<kernel::LiteKernel *> use_in_tensor_kernels;
981     auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
982     if (status != RET_OK) {
983       MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
984       return RET_ERROR;
985     }
986     for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
987       need_kernel_names.push(use_in_tensor_kernels[i]->name());
988     }
989   }
990   return RET_OK;
991 }
992 
993 int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
994                          FormatType format, std::vector<std::string> out_put_tensor_name) {
995   if (file_name.empty()) {
996     MS_LOG(ERROR) << "File name cannot be empty";
997     return RET_ERROR;
998   }
999   if (model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN) {
1000     MS_LOG(ERROR) << "Export model type parameter error";
1001     return RET_ERROR;
1002   }
1003   if (quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT) {
1004     MS_LOG(ERROR) << "Export quant type parameter error";
1005     return RET_ERROR;
1006   }
1007   if (format != FT_FLATBUFFERS) {
1008     MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
1009     return RET_ERROR;
1010   }
1011 
1012   bool orig_train_state = IsTrain();
1013   Eval();
1014   TrainExport texport(file_name);
1015   int status = texport.ExportInit(model_.get()->name_, model_.get()->version_);
1016   if (status != RET_OK) {
1017     MS_LOG(ERROR) << "cannot init export";
1018     return status;
1019   }
1020 
1021   if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
1022     std::vector<kernel::LiteKernel *> export_kernels = {};
1023     status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
1024     if (status != RET_OK) {
1025       MS_LOG(ERROR) << "FindExportKernels failed.";
1026       return RET_ERROR;
1027     }
1028     status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
1029   } else {
1030     status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
1031                                (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
1032                                model_.get(), quant_type);
1033   }
1034 
1035   if (status != RET_OK) {
1036     MS_LOG(ERROR) << "cannot export Network";
1037     return status;
1038   }
1039   status = texport.SaveToFile();
1040   if (status != RET_OK) {
1041     MS_LOG(ERROR) << "failed to save to " << file_name;
1042     return status;
1043   }
1044   if (orig_train_state) Train();
1045   return status;
1046 }
1047 std::vector<tensor::MSTensor *> TrainSession::GetFeatureMaps() const {
1048   std::vector<tensor::MSTensor *> features;
1049   for (auto cur_tensor : this->tensors_) {
1050     if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
1051       features.push_back(cur_tensor);
1052     }
1053   }
1054   return features;
1055 }
1056 
1057 int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) {
1058   for (auto feature : features_map) {
1059     bool find = false;
1060     for (auto tensor : tensors_) {
1061       if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
1062         continue;
1063       }
1064       if (feature->tensor_name() != tensor->tensor_name()) {
1065         continue;
1066       }
1067       if (feature->Size() != tensor->Size()) {
1068         MS_LOG(ERROR) << "feature name: " << feature->tensor_name() << ", size mismatch:"
1069                       << " old is " << tensor->Size() << ", new is " << feature->Size();
1070         return RET_ERROR;
1071       }
1072       find = true;
1073       memcpy(tensor->data(), feature->data(), tensor->Size());
1074     }
1075     if (!find) {
1076       MS_LOG(ERROR) << "cannot find feature:" << feature->tensor_name() << ",update failed";
1077       return RET_ERROR;
1078     }
1079   }
1080   return RET_OK;
1081 }
1082 }  // namespace lite
1083 
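// Factory entry point: imports the model (appending ".ms" to the file name if missing), builds
// the train graph and switches to train or eval mode. A minimal usage sketch, assuming a model
// file "lenet_tod.ms" exists and default context settings are acceptable (names below are
// illustrative only, not part of this file):
//
//   lite::Context context;
//   auto *session =
//     session::TrainSession::CreateTrainSession("lenet_tod", &context, /*train_mode=*/true, nullptr);
//   // ... set inputs, call session->RunGraph(), session->Export(...) as needed, then:
//   delete session;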
1084 session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
1085                                                                 bool train_mode, const lite::TrainCfg *cfg) {
1086   if (context == nullptr) {
1087     MS_LOG(ERROR) << "context cannot be nullptr";
1088     return nullptr;
1089   }
1090   auto session = std::make_unique<lite::TrainSession>();
1091   if (session == nullptr) {
1092     MS_LOG(ERROR) << "create session failed";
1093     return nullptr;
1094   }
1095   if (context->allocator == nullptr) {
1096     const_cast<lite::Context *>(context)->allocator = std::make_shared<StaticAllocator>();
1097     if (context->allocator == nullptr) {
1098       MS_LOG(ERROR) << "cannot create static allocator";
1099     }
1100   }
1101 
1102   auto *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
1103   auto ret = session->Init(inner_context, cfg);
1104   if (ret != mindspore::lite::RET_OK) {
1105     MS_LOG(ERROR) << "init session failed";
1106     return nullptr;
1107   }
1108 
1109   std::string filename = fn;
1110   if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
1111     filename = filename + ".ms";
1112   }
1113 
1114   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
1115   if (model == nullptr) {
1116     MS_LOG(ERROR) << "create model for train session failed " << filename;
1117     return nullptr;
1118   }
1119 
1120   ret = session->CompileTrainGraph(model);
1121   if (ret != mindspore::lite::RET_OK) {
1122     MS_LOG(ERROR) << "Compiling Train Graph session failed";
1123     return nullptr;
1124   }
1125 
1126   if (train_mode) {
1127     ret = session->Train();
1128   } else {
1129     ret = session->Eval();
1130   }
1131   if (ret != mindspore::lite::RET_OK) {
1132     MS_LOG(ERROR) << "Could not switch to train mode " << train_mode;
1133     return nullptr;
1134   }
1135   return session.release();
1136 }
1137 }  // namespace mindspore
1138