/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/train_session.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include <unordered_map>
#include "include/errorcode.h"
#include "src/executor.h"
#include "src/lite_model.h"
#include "src/lite_kernel_util.h"
#include "src/sub_graph_kernel.h"
#include "src/tensor.h"
#include "src/kernel_registry.h"
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "src/common/utils.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "src/runtime/kernel/arm/fp32/batchnorm_fp32.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/train/train_utils.h"
#include "src/train/train_export.h"
#include "src/train/opt_allocator.h"
#include "src/train/static_allocator.h"
#include "src/train/train_populate_parameter.h"
#include "src/train/train_populate_parameter_v0.h"

namespace mindspore {
namespace lite {
const char *kGradName = "Gradients";
const char *kOptimizerName = "optimizer";

TrainSession::TrainSession() {
  is_train_session_ = true;
  InitCallBack();
}

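// Validates the optional training configuration (loss scale must be positive) and the context,
// caches the allocator, and delegates the rest of the setup to LiteSession::Init.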
int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
  if (train_cfg != nullptr) {
    if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
      MS_LOG(ERROR) << "illegal loss scale configuration";
      return RET_NULL_PTR;
    }
    cfg_ = *train_cfg;
  }
  if (context == nullptr) {
    MS_LOG(ERROR) << "context cannot be nullptr";
    return RET_NULL_PTR;
  }
  allocator_ = context->allocator;
  return lite::LiteSession::Init(context);
}

std::vector<CreatorOp> TrainSession::ReplaceOps() {
  const std::vector<CreatorOp> replace = {
    // currently no ops are Hijacked by TrainSession
  };
  mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
  std::vector<CreatorOp> results;
  for (auto v : replace) {
    const CreatorOp cl = make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v)));
    results.push_back(cl);
    reg->RegKernel(std::get<0>(v), std::get<1>(v));
  }
  return results;
}

void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) {
  mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
  for (auto v : restore) {
    reg->RegKernel(std::get<0>(v), std::get<1>(v));
  }
}

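// Allocates one shared workspace buffer sized for the largest per-kernel workspace requirement
// and hands the same buffer to every train kernel.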
int TrainSession::AllocWorkSpace() {
  size_t workspace_size = 0;
  for (auto kernel : this->train_kernels_) {
    if (workspace_size < static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size()) {
      workspace_size = static_cast<kernel::InnerKernel *>(kernel->kernel())->workspace_size();
    }
  }
  workspace_ = malloc(workspace_size);
  if (workspace_ == nullptr) {
    MS_LOG(ERROR) << "cannot allocate " << workspace_size << " for workspace";
    return RET_ERROR;
  }
  for (auto kernel : this->train_kernels_) {
    static_cast<kernel::InnerKernel *>(kernel->kernel())->set_workspace(workspace_);
  }
  return RET_OK;
}

void TrainSession::FreeWorkSpace() {
  if (workspace_ != nullptr) {
    free(workspace_);
    workspace_ = nullptr;
  }
  for (auto kernel : this->train_kernels_) {
    static_cast<kernel::InnerKernel *>(kernel->kernel())->FreeWorkspace();
  }
}

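// Registers the scheduler callback that decides, per node, whether it runs in fp16 or fp32 when
// CPU fp16 is enabled: raw mix precision follows the tensor data types stored in the model;
// otherwise Cast nodes, optimizers, loss nodes and (optionally) batchnorm nodes stay in fp32,
// while nodes with fp16 constant inputs are forced to fp16.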
int TrainSession::InitCallBack() {
  sched_mix_precision_callback_ = [&](const Model::Node *node) {
    if (!context_->IsCpuFloat16Enabled()) {
      return false;
    }
    if (cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
      auto out_tensor_indexs = node->output_indices_;
      if (out_tensor_indexs.empty()) {
        MS_LOG(DEBUG) << "Debug: " << node->name_ << " fp32";
        return false;
      }
      auto is_fp16 = model_->all_tensors_.at(out_tensor_indexs[0])->dataType() == kNumberTypeFloat16;
      MS_LOG(DEBUG) << "Debug: " << node->name_ << ((is_fp16) ? " fp16" : " fp32");
      return is_fp16;
    }
    auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
    if (node_type == schema::PrimitiveType_Cast) {
      return false;
    }
    auto in_size = node->input_indices_.size();
    bool force_fp16 = false;
    for (std::size_t k = 0; k < in_size; k++) {
      schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
      if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
        force_fp16 = true;
        break;
      }
    }
    const auto &node_name = node->name_;
    bool is_fp16 = true;
    if (!force_fp16) {
      // optimizer runs in fp32
      if (node_name.find(kOptimizerName) != std::string::npos) {
        is_fp16 = false;
      }
      // loss function runs in fp32
      if ((node_name.find(get_loss_name()) != std::string::npos)) {
        is_fp16 = false;
      }
      // run bn according to user configuration
      if ((cfg_.mix_precision_cfg_.keep_batchnorm_fp32_) &&
          (node_type == schema::PrimitiveType_FusedBatchNorm || node_type == schema::PrimitiveType_BatchNorm ||
           node_type == schema::PrimitiveType_BatchNormGrad)) {
        is_fp16 = false;
      }
    }
    MS_LOG(DEBUG) << "Debug: " << node_name << ((is_fp16) ? " fp16" : " fp32");
    return is_fp16;
  };
  return RET_OK;
}

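// When a static allocator is in use, plans an offset for every kernel output tensor with a simple
// liveness analysis driven by input reference counts, allocates one backing buffer of the planned
// total size, and binds each planned tensor to its offset inside that buffer.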
int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) {
  if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
  OptAllocator allocator;
  std::unordered_map<lite::Tensor *, int> ref_count;
  std::unordered_map<lite::Tensor *, size_t> offset_map;
  for (auto kernel : kernels) {
    for (auto tensor : kernel->out_tensors()) {
      size_t size = tensor->Size();
      size_t offset = allocator.Malloc(size);
      offset_map[tensor] = offset;
      ref_count[tensor] = tensor->init_ref_count();
    }
    for (auto tensor : kernel->in_tensors()) {
      if (tensor->category() == lite::Tensor::VAR) {
        int count = ref_count[tensor] - 1;
        ref_count[tensor] = count;
        if (count == 0) {
          allocator.Free(offset_map[tensor]);
        }
      }
    }
  }
  // Set Tensor data
  if (tensors_data_ == nullptr) {
    auto size = allocator.total_size();
    auto buf = malloc(size);
    if (buf == nullptr) {
      MS_LOG(ERROR) << "cannot allocate buffer of size " << size;
      return RET_ERROR;
    }
    StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
    alloc->SetContex(buf, size);
    tensors_data_ = buf;
  }
  for (auto kernel : train_kernels_) {
    for (auto tensor : kernel->out_tensors()) {
      auto it = offset_map.find(tensor);
      if (it != offset_map.end()) {
        tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<char *>(tensors_data_) + it->second));
      }
    }
  }
  return RET_OK;
}

int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }

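// Compiles the training graph: populates train kernel parameters for the model schema version,
// compiles the graph through LiteSession, records the original outputs, then derives the
// train/eval kernel lists and outputs before allocating the workspace and planned tensors.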
int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
  model_ = model;
  auto restore = ReplaceOps();
  sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
  if (sched_cb_ == nullptr) {
    MS_LOG(ERROR) << "Failed to create SchedulerCb node";
    return RET_ERROR;
  }

#ifdef ENABLE_V0
  if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_V0) {
    kernel::PopulateTrainV0Parameters();
  }
#endif
  if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_CUR) {
    kernel::PopulateTrainParameters();
  }

  auto ret = lite::LiteSession::CompileGraph(model_.get());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to compile train model";
    return RET_ERROR;
  }
  orig_output_node_map_ = output_node_map_;
  orig_output_tensor_map_ = output_tensor_map_;
  orig_output_tensor_names_ = output_tensor_names_;
  for (auto inTensor : inputs_) inTensor->MutableData();
  RestoreOps(restore);
  CompileTrainKernels();      // Prepare a list of train kernels
  CompileOptimizedKernels();  // Prepare a list of kernels which are optimized (weight update step)
  CompileTrainOutputs();      // prepare outputs in train mode
  CompileEvalOutputs();       // prepare outputs in eval mode
  // Prepare a list of eval kernels
  if (CompileInferenceKernels() != RET_OK) {
    MS_LOG(ERROR) << "CompileInferenceKernels failed.";
    return RET_ERROR;
  }
  ret = AllocWorkSpace();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate workspace";
    return RET_ERROR;
  }
  ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate tensor space";
    return RET_ERROR;
  }
  return RET_OK;
}

TrainSession::~TrainSession() {
  FreeWorkSpace();
  if (tensors_data_ != nullptr) {
    free(tensors_data_);
    tensors_data_ = nullptr;
  }
}

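// Runs the given kernels in order, invoking the optional before/after callbacks for each one.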
int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                              const std::vector<kernel::LiteKernel *> &run_kernels) {
  for (auto *kernel : run_kernels) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Execute(before, after);
    if (RET_OK != ret) {
      MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
      return ret;
    }
  }
  return RET_OK;
}

void TrainSession::RestoreTensorData() {
  for (auto &restored_origin_tensor : restored_origin_tensors_) {
    auto *origin_tensor = restored_origin_tensor.first;
    auto *restored_tensor = restored_origin_tensor.second;
    MS_ASSERT(origin_tensor != nullptr);
    MS_ASSERT(restored_tensor != nullptr);

    bool own_data = restored_tensor->own_data();
    if (origin_tensor->data() == nullptr) {
      restored_tensor->FreeData();
    } else {
      origin_tensor->FreeData();
    }
    origin_tensor->set_data_type(restored_tensor->data_type());
    origin_tensor->set_data(restored_tensor->data());
    origin_tensor->set_own_data(own_data);
  }
}

void TrainSession::FreeRestoreTensors() {
  for (auto &restored_origin_tensor : restored_origin_tensors_) {
    auto *restored_tensor = restored_origin_tensor.second;
    restored_tensor->set_data(nullptr);
    delete (restored_tensor);
  }
  restored_origin_tensors_.clear();
}

bool TrainSession::IsLossTensor(Tensor *tensor) {
  MS_ASSERT(tensor != nullptr);
  auto t_n = tensor->tensor_name();
  return (t_n.find(get_loss_name()) != std::string::npos);
}

bool TrainSession::AllInputsNeedScale(kernel::LiteKernel *kernel) {
  auto type = kernel->type();
  bool is_scale = false;
  switch (type) {
    case schema::PrimitiveType_AbsGrad:
    case schema::PrimitiveType_AddFusion:
    case schema::PrimitiveType_SubFusion:
    case schema::PrimitiveType_AddN:
      for (auto &tensor : kernel->in_tensors()) {
        is_scale = is_scale || tensor->IsScale();
      }
      return (is_scale);
    default:
      return false;
  }
  return false;
}

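// Before running a kernel in mixed precision: applies the loss scale to loss-related input
// tensors (or to all inputs for ops that require it) and casts inputs to the kernel's data type,
// keeping the original tensors so they can be restored after execution.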
int TrainSession::MixPrecisionPreProcess(kernel::LiteKernel *kernel, float scale) {
  auto kernel_type = kernel->desc().data_type;
  auto all_scale = AllInputsNeedScale(kernel);

  for (auto &tensor : kernel->in_tensors()) {
    if ((tensor->IsScale() == false) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
      ScaleTensor(tensor, scale);
    }
    // adjust tensor data type
    if (tensor->data_type() != kernel_type) {
      auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16());
      if (restore_tensor != nullptr) {
        restored_origin_tensors_[tensor] = restore_tensor;
      }
    }
  }
  return RET_OK;
}

int TrainSession::MixPrecisionPostProcess(kernel::LiteKernel *kernel) {
  RestoreTensorData();
  FreeRestoreTensors();

  float scale = 1.0f;
  auto all_scale = AllInputsNeedScale(kernel);
  for (auto &tensor : kernel->in_tensors()) {
    if (tensor->IsScale()) {
      scale *= tensor->get_scale();
      if (all_scale) {
        break;
      }
    }
  }
  for (auto &tensor : kernel->out_tensors()) {
    tensor->set_scale(scale);
  }

  for (auto &tensor : kernel->in_tensors()) {
    if ((tensor->IsScale() == true) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
      ScaleTensor(tensor, 1.0f / scale);
    }
  }
  return RET_OK;
}

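// Mixed-precision execution loop: scales and casts inputs, runs each kernel, and restores the
// original tensors afterwards. On overflow (RET_OUT_OF_TENSOR_RANGE) the dynamic loss scale is
// halved and the iteration is abandoned; after enough clean iterations it is doubled again.
// In eval mode any fp16 outputs are cast back to fp32.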
int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                                          const std::vector<kernel::LiteKernel *> &run_kernels) {
  float scale = cfg_.mix_precision_cfg_.loss_scale_;
  for (auto *kernel : run_kernels) {
    MS_ASSERT(kernel != nullptr);
    MixPrecisionPreProcess(kernel, scale);
    auto ret = kernel->Execute(before, after);
    if (RET_OK != ret) {
      MixPrecisionPostProcess(kernel);
      // decrease loss scale in case of nan or inf
      if (ret == RET_OUT_OF_TENSOR_RANGE) {
        bool is_dynamic_scale = cfg_.mix_precision_cfg_.dynamic_loss_scale_;
        cfg_.mix_precision_cfg_.loss_scale_ = std::max(((is_dynamic_scale) ? (scale / 2.f) : scale), 1.0f);
        num_of_not_nan_iter_ = 0;
        return RET_OK;
      }
      MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
      return ret;
    }
    MixPrecisionPostProcess(kernel);
  }
  // increase the dynamic loss scale once enough iterations pass without overflow
  if (cfg_.mix_precision_cfg_.dynamic_loss_scale_) {
    num_of_not_nan_iter_++;
    if (num_of_not_nan_iter_ >= cfg_.mix_precision_cfg_.num_of_not_nan_iter_th_) {
      cfg_.mix_precision_cfg_.loss_scale_ = std::min((cfg_.mix_precision_cfg_.loss_scale_ * 2.0f), 65536.0f);
      num_of_not_nan_iter_ = 0;
    }
  }

  // cast output to FP32
  if (train_mode_ == false) {
    for (auto t : this->outputs_) {
      if (t->data_type() == kNumberTypeFloat16) {
        auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16());
        delete restore;
      }
    }
  }
  return RET_OK;
}

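// Validates the inputs, rebuilds the output tensor list for the current mode, and executes either
// the train or the inference kernels, taking the mixed-precision path when CPU fp16 is enabled.
// With virtual batching, the optimizer step runs once every virtual_batch_multiplier_ iterations.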
int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
  // check inputs
  auto ret = CheckTensorsInvalid(inputs_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CheckInputs failed";
    return ret;
  }

  // build the output tensor list for the current mode
  this->outputs_.clear();
  for (auto &ms_tensors : output_node_map_) {
    for (auto &ms_tensor : ms_tensors.second) {
      auto lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      this->outputs_.push_back(lite_tensor);
    }
  }

  if (this->context_ == nullptr) {
    MS_LOG(ERROR) << "context is null";
    return lite::RET_NULL_PTR;
  }
  auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
  if (context_->IsCpuFloat16Enabled() && !cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
    ret = MixPrecisionExecKernels(before, after, run_kernels);
  } else {
    ret = ExecKernels(before, after, run_kernels);
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to run model kernels";
    return ret;
  }

  if (train_mode_ && virtual_batch_multiplier_) {
    virtual_batch_idx_++;
    if (virtual_batch_idx_ >= virtual_batch_multiplier_) {
      virtual_batch_idx_ = 0;
      ret = OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "failed to optimize model weights";
        return ret;
      }
    }
  }
  return RET_OK;
}

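// Switches every kernel to train mode, exposes the train outputs, and re-plans tensor allocation;
// eval outputs keep an extra reference so predictions stay visible during training.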
int TrainSession::Train() {
  // shift kernels to train mode
  train_mode_ = true;
  virtual_batch_idx_ = 0;
  for (auto &kernel : this->train_kernels_) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Train();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
      return RET_ERROR;
    }
  }
  // set train outputs
  output_node_map_ = train_output_node_map_;
  output_tensor_map_ = train_output_tensor_map_;
  output_tensor_names_ = train_output_tensor_names_;
  kernel::LiteKernelUtil::InitTensorInitRefCount(train_kernels_);
  for (auto &ms_tensors : eval_output_node_map_) {  // allow looking at predictions during training as well
    for (auto &ms_tensor : ms_tensors.second) {
      lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
    }
  }
  // allocate tensors
  auto ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate tensor space";
    return RET_ERROR;
  }
  return RET_OK;
}

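// Switches every kernel to eval mode, exposes the eval outputs, and re-plans tensor allocation
// for the inference kernels only.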
int TrainSession::Eval() {
  // shift kernels to eval mode
  train_mode_ = false;
  virtual_batch_idx_ = 0;
  for (auto &kernel : this->train_kernels_) {
    MS_ASSERT(kernel != nullptr);
    auto ret = kernel->Eval();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << kernel->name() << " failed to set eval mode";
      return RET_ERROR;
    }
  }
  // set eval outputs
  output_node_map_ = eval_output_node_map_;
  output_tensor_map_ = eval_output_tensor_map_;
  output_tensor_names_ = eval_output_tensor_names_;
  kernel::LiteKernelUtil::InitTensorInitRefCount(inference_kernels_);
  for (auto &ms_tensors : eval_output_node_map_) {
    for (auto &ms_tensor : ms_tensors.second) {
      lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
      lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
    }
  }
  auto ret = AllocTensors(inference_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  return RET_OK;
}

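// The eval outputs are the non-loss, non-gradient kernels that feed a loss kernel; if none are
// found, the original model outputs are used instead.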
void TrainSession::CompileEvalOutputs() {
  eval_output_node_map_.clear();
  eval_output_tensor_map_.clear();
  eval_output_tensor_names_.clear();
  for (auto kernel : this->train_kernels_) {
    if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) {
      for (auto in_kernel : kernel->in_kernels()) {
        if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
        // insert if not already in
        if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) {
          auto *ms_tensor = in_kernel->out_tensors().at(0);
          if (ms_tensor != nullptr) {
            ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1);
            eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor);
            auto index = TSFindTensor(tensors_, ms_tensor);
            if (index != tensors_.size()) {
              eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
              if (!ms_tensor->tensor_name().empty()) {
                eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
              } else {
                eval_output_tensor_names_.emplace_back(std::to_string(index));
              }
            }
          }
        }
      }
    }
  }
  if (eval_output_node_map_.size() == 0) eval_output_node_map_ = orig_output_node_map_;
  if (eval_output_tensor_map_.size() == 0) eval_output_tensor_map_ = orig_output_tensor_map_;
  if (eval_output_tensor_names_.size() == 0) eval_output_tensor_names_ = orig_output_tensor_names_;
}

void TrainSession::CompileTrainOutputs() {
  train_output_node_map_.clear();
  train_output_tensor_map_.clear();
  train_output_tensor_names_.clear();
  for (auto kernel : this->train_kernels_) {
    if (orig_output_node_map_.find(kernel->name()) == orig_output_node_map_.end()) continue;
    // Mask out optimizer out tensors
    if (IsMaskOutput(kernel)) continue;
    // insert if not already in
    if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) {
      auto *ms_tensor = kernel->out_tensors().at(0);
      if (ms_tensor != nullptr) {
        train_output_node_map_[kernel->name()].emplace_back(ms_tensor);
        auto index = TSFindTensor(tensors_, ms_tensor);
        if (index != tensors_.size()) {
          train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
          if (!ms_tensor->tensor_name().empty()) {
            train_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
          } else {
            train_output_tensor_names_.emplace_back(std::to_string(index));
          }
        }
      }
    }
  }
  if (train_output_node_map_.size() == 0) train_output_node_map_ = orig_output_node_map_;
  if (train_output_tensor_map_.size() == 0) train_output_tensor_map_ = orig_output_tensor_map_;
  if (train_output_tensor_names_.size() == 0) train_output_tensor_names_ = orig_output_tensor_names_;
}

void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) {
  MS_ASSERT(kernel != nullptr);
  MS_ASSERT(v != nullptr);
  if (std::find(v->begin(), v->end(), kernel) == v->end()) {  // kernel is not already in vector
    for (auto in_node : kernel->in_kernels()) {
      BuildInferenceKernelsRecursive(in_node, v);
    }
    if (!IsLossKernel(kernel)) v->push_back(kernel);
  }
}

void TrainSession::CompileTrainKernels() {
  train_kernels_.clear();
  for (auto ori_kernel : kernels_) {
    if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
      train_kernels_.push_back(ori_kernel);
    } else {
      auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
      for (auto kernel : sub_graph->nodes()) {
        train_kernels_.push_back(kernel);
      }
    }
  }
}

int TrainSession::CompileInferenceKernels() {
  inference_kernels_.clear();
  for (auto item : eval_output_node_map_) {
    std::string kernel_name = item.first;
    auto kernel = TSFindKernel(train_kernels_, kernel_name);
    if (kernel == nullptr) {
      MS_LOG(ERROR) << "kernel is nullptr";
      return RET_ERROR;
    }
    BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
  }
  if (inference_kernels_.size() == 0) {
    inference_kernels_ = this->train_kernels_;
  }
  return RET_OK;
}

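// Collects the tensors consumed by optimizer kernels (trainable weights and optimizer state) and
// marks every non-optimizer kernel that reads one of them as trainable. Also switches optimizers
// to gradient accumulation when the configuration requests it.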
void TrainSession::CompileOptimizedKernels() {
  std::vector<lite::Tensor *> out_tensor;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
      if (cfg_.accumulate_gradients_) {
        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
        optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
      }
    }
  }

  for (auto kernel : this->train_kernels_) {
    if (!IsOptimizer(kernel)) {
      for (auto it : kernel->in_tensors()) {
        if (std::find(out_tensor.begin(), out_tensor.end(), it) != out_tensor.end()) {
          kernel->SetTrainable(true);
          break;
        }
      }
    }
  }
}

int TrainSession::SetLearningRate(float learning_rate) {
  if (learning_rate < 0.0f) {
    MS_LOG(ERROR) << "learning rate should not be negative";
    return RET_ERROR;
  }
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto ret = optimizer->SetLearningRate(learning_rate);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

float TrainSession::GetLearningRate() {
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      return optimizer->GetLearningRate();
    }
  }
  return 0.0;
}

std::vector<tensor::MSTensor *> TrainSession::GetOptimizerParams() const {
  std::vector<tensor::MSTensor *> params;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto kernelParams = optimizer->GetOptimizerParams();
      for (size_t ix = 0; ix < kernelParams.size(); ix++) {
        auto kernelParam = kernelParams[ix];
        auto name = kernelParam->tensor_name();
        bool found = false;
        for (size_t iy = 0; iy < params.size(); iy++) {
          if (params[iy]->tensor_name() == name) {
            found = true;
            break;
          }
        }
        if (!found) {
          params.push_back(kernelParam);
        }
      }
    }
  }
  return params;
}

int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) {
  for (size_t ix = 0; ix < params.size(); ix++) {
    auto param = params[ix];
    if (param == nullptr) {
      MS_LOG(ERROR) << "Param tensor is nullptr.";
      return RET_ERROR;
    }
    bool found = false;
    for (auto kernel : this->train_kernels_) {
      if (IsOptimizer(kernel)) {
        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
        found = optimizer->SetOptimizerParams(param);
        if (found) break;
      }
    }
    if (!found) {
      MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elements and type "
                    << param->data_type() << " is not a valid params tensor";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

std::vector<tensor::MSTensor *> TrainSession::GetGradients() const {
  std::vector<tensor::MSTensor *> gradients;
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto kernelGradient = optimizer->GetGradients();
      if (kernelGradient != nullptr) {
        gradients.push_back(kernelGradient);
      }
    }
  }
  return gradients;
}

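// Copies user-supplied gradient tensors (matched by name and size) into the optimizer gradient
// buffers, then triggers an optimizer step on every optimizer kernel.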
int TrainSession::ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) {
  auto current_gradients = GetGradients();
  if (current_gradients.size() != gradients.size()) {
    MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
                  << current_gradients.size();
    return RET_ERROR;
  }
  for (size_t ix = 0; ix < gradients.size(); ix++) {
    auto gradient = gradients[ix];
    if (gradient == nullptr) {
      MS_LOG(ERROR) << "gradient tensor is nullptr.";
      return RET_ERROR;
    }
    bool found = false;
    for (size_t iy = 0; iy < current_gradients.size(); iy++) {
      auto current_gradient = current_gradients[iy];
      if (current_gradient->tensor_name() == gradient->tensor_name()) {
        found = true;
        if (current_gradient->Size() == gradient->Size()) {
          std::copy(static_cast<char *>(gradient->data()), static_cast<char *>(gradient->data()) + gradient->Size(),
                    static_cast<char *>(current_gradient->MutableData()));
        } else {
          MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
                        << " instead of " << current_gradient->Size();
          return RET_ERROR;
        }
        break;
      }
    }
    if (!found) {
      MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
      return RET_ERROR;
    }
  }
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      optimizer->set_grad_sum_valid();
      auto ret = optimizer->OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "failed to optimize model weights";
        return ret;
      }
    }
  }
  for (size_t ix = 0; ix < current_gradients.size(); ix++) {
    delete current_gradients[ix];
  }
  return RET_OK;
}

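// Puts optimizers (and trainable batchnorm kernels) into virtual-batch mode when the multiplier
// is greater than 1, scaling down the default learning rate and batchnorm momentum, and restores
// the defaults when the multiplier is 1 or less.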
int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
  auto mod =
    (virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
  virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
  virtual_batch_idx_ = 0;

  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
          optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
        MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
        return RET_ERROR;
      }
      auto ret = optimizer->SetOptimizerMode(mod);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
        return RET_ERROR;
      }
      if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
        lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
        ret = optimizer->SetLearningRate(lr);
      } else {
        ret = optimizer->RestoreDefaultLearningRate();
      }
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
        return RET_ERROR;
      }
    }

    if (IsBN(kernel) && kernel->IsTrainable()) {
      auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
      auto ret = RET_OK;
      if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
        momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
        ret = batchnorm->set_momentum(momentum);
      } else {
        ret = batchnorm->RestoreDefaultMomentum();
      }
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to set momentum";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
  int tmp = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
  if (tmp != 0 && virtual_batch_multiplier_ != 0) {
    AdminSetupVirtualBatch(0, lr, momentum);
  }
  return AdminSetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
}

int TrainSession::OptimizerStep() {
  for (auto kernel : this->train_kernels_) {
    if (IsOptimizer(kernel)) {
      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
      auto ret = optimizer->OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
  return (kernel->type() == schema::PrimitiveType_SoftmaxCrossEntropyWithLogits ||
          kernel->type() == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits ||
          kernel->type() == schema::PrimitiveType_SmoothL1Loss ||
          kernel->type() == schema::PrimitiveType_SmoothL1LossGrad ||
          kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
          kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad) ||
         kernel->name().find(cfg_.loss_name_) != std::string::npos;
}

bool TrainSession::IsGradKernel(const kernel::LiteKernel *kernel) const {
  return kernel->name().find(kGradName) != std::string::npos;
}

bool TrainSession::IsOptimizer(kernel::LiteKernel *kernel) const {
  return ((kernel->type() == schema::PrimitiveType_Adam) || (kernel->type() == schema::PrimitiveType_SGD) ||
          (kernel->type() == schema::PrimitiveType_ApplyMomentum));
}

bool TrainSession::IsMaskOutput(kernel::LiteKernel *kernel) const {
  return (IsOptimizer(kernel) || (kernel->type() == schema::PrimitiveType_Assign));
}

bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
  return ((kernel->type() == schema::PrimitiveType_BatchNorm) ||
          (kernel->type() == schema::PrimitiveType_FusedBatchNorm));
}

int TrainSession::Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) {
  FreeWorkSpace();
  if (tensors_data_ != nullptr) {
    free(tensors_data_);
    tensors_data_ = nullptr;
  }
  auto ret = lite::LiteSession::Resize(inputs, dims);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "train resize input failed.";
    return RET_ERROR;
  }
  ret = AllocWorkSpace();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "failed to allocate space";
    return RET_ERROR;
  }
  ret = AllocTensors(train_kernels_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "train alloc failed after resize.";
    return RET_ERROR;
  }
  return RET_OK;
}

int TrainSession::FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
                                        const std::vector<lite::Tensor *> &kernel_in_tensors,
                                        const std::vector<kernel::LiteKernel *> &inference_kernels) {
  for (size_t i = 0; i < inference_kernels.size(); i++) {
    for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
      if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
        use_in_tensor_kernels->push_back(inference_kernels[i]);
      }
    }
  }
  return RET_OK;
}

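// Resolves the kernels needed to produce the requested output tensors: starts from the kernels
// that own those tensors and walks backwards through the producers of their inputs.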
int TrainSession::FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
                                    const std::vector<std::string> &export_output_tensor_names,
                                    const std::vector<kernel::LiteKernel *> &inference_kernels) {
  std::vector<std::string> all_kernel_name = {};
  std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
                 [](kernel::LiteKernel *kernel) { return kernel->name(); });
  std::queue<std::string> need_kernel_names;
  // Find the kernel names according to the requested output tensor names
  for (auto &kernel : inference_kernels) {
    if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
          return IsContain(export_output_tensor_names, out_tensor->tensor_name());
        })) {
      need_kernel_names.push(kernel->name());
    }
  }
  if (need_kernel_names.size() == 0) {
    MS_LOG(ERROR) << "cannot find the requested export output tensors";
    return RET_ERROR;
  }
  // find all kernels needed to produce them
  while (!need_kernel_names.empty()) {
    auto kernel_name = need_kernel_names.front();
    need_kernel_names.pop();
    auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
    if (it == all_kernel_name.end()) {
      MS_LOG(ERROR) << "cannot find kernel name in exported trained model.";
      return RET_ERROR;
    }
    auto kernel = inference_kernels[it - all_kernel_name.begin()];
    if (!IsContain(*export_kernels, kernel)) {
      export_kernels->push_back(kernel);
    }
    auto kernel_in_tensors = kernel->in_tensors();
    std::vector<kernel::LiteKernel *> use_in_tensor_kernels;
    auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
      return RET_ERROR;
    }
    for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
      need_kernel_names.push(use_in_tensor_kernels[i]->name());
    }
  }
  return RET_OK;
}

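// Exports the model to a flatbuffer file, either the full training graph (MT_TRAIN) or the
// inference subgraph (MT_INFERENCE), optionally restricted to user-selected output tensors.
// The session is temporarily switched to eval mode while exporting.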
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
                         FormatType format, std::vector<std::string> out_put_tensor_name) {
  if (file_name.empty()) {
    MS_LOG(ERROR) << "File name cannot be empty";
    return RET_ERROR;
  }
  if (model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN) {
    MS_LOG(ERROR) << "Export model type parameter error";
    return RET_ERROR;
  }
  if (quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT) {
    MS_LOG(ERROR) << "Export quant type parameter error";
    return RET_ERROR;
  }
  if (format != FT_FLATBUFFERS) {
    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
    return RET_ERROR;
  }

  bool orig_train_state = IsTrain();
  Eval();
  TrainExport texport(file_name);
  int status = texport.ExportInit(model_.get()->name_, model_.get()->version_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot init export";
    return status;
  }

  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    std::vector<kernel::LiteKernel *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "FindExportKernels failed.";
      return RET_ERROR;
    }
    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
  } else {
    status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
                               (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
                               model_.get(), quant_type);
  }

  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot export Network";
    return status;
  }
  status = texport.SaveToFile();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "failed to save to " << file_name;
    return status;
  }
  if (orig_train_state) Train();
  return status;
}

std::vector<tensor::MSTensor *> TrainSession::GetFeatureMaps() const {
  std::vector<tensor::MSTensor *> features;
  for (auto cur_tensor : this->tensors_) {
    if (cur_tensor->IsConst() && cur_tensor->data_type() == kNumberTypeFloat32) {
      features.push_back(cur_tensor);
    }
  }
  return features;
}

int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) {
  for (auto feature : features_map) {
    bool find = false;
    for (auto tensor : tensors_) {
      if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
        continue;
      }
      if (feature->tensor_name() != tensor->tensor_name()) {
        continue;
      }
      if (feature->Size() != tensor->Size()) {
        MS_LOG(ERROR) << "feature name:" << feature->tensor_name() << ", size mismatch: old is " << tensor->Size()
                      << ", new is " << feature->Size();
        return RET_ERROR;
      }
      find = true;
      memcpy(tensor->data(), feature->data(), tensor->Size());
    }
    if (!find) {
      MS_LOG(ERROR) << "cannot find feature:" << feature->tensor_name() << ", update failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
}  // namespace lite

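// Factory used by the public TrainSession API: wraps context/allocator setup, model import from a
// ".ms" file, graph compilation, and the initial train/eval mode switch.
// A minimal usage sketch (the file name and context settings below are illustrative only):
//   lite::Context context;
//   lite::TrainCfg train_cfg;
//   auto *session =
//     session::TrainSession::CreateTrainSession("my_model.ms", &context, true, &train_cfg);
//   // ... fill inputs, session->RunGraph(), read outputs ...
//   delete session;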
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
                                                                 bool train_mode, const lite::TrainCfg *cfg) {
  if (context == nullptr) {
    MS_LOG(ERROR) << "context cannot be nullptr";
    return nullptr;
  }
  auto session = std::make_unique<lite::TrainSession>();
  if (session == nullptr) {
    MS_LOG(ERROR) << "create session failed";
    return nullptr;
  }
  if (context->allocator == nullptr) {
    const_cast<lite::Context *>(context)->allocator = std::make_shared<StaticAllocator>();
    if (context->allocator == nullptr) {
      MS_LOG(ERROR) << "cannot create static allocator";
    }
  }

  auto *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
  if (inner_context == nullptr) {
    MS_LOG(ERROR) << "create inner context failed";
    return nullptr;
  }
  auto ret = session->Init(inner_context, cfg);
  if (ret != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "init session failed";
    return nullptr;
  }

  std::string filename = fn;
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for train session failed " << filename;
    return nullptr;
  }

  ret = session->CompileTrainGraph(model);
  if (ret != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "Compiling Train Graph session failed";
    return nullptr;
  }

  if (train_mode) {
    ret = session->Train();
  } else {
    ret = session->Eval();
  }
  if (ret != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "Could not switch to " << (train_mode ? "train" : "eval") << " mode";
    return nullptr;
  }
  return session.release();
}
}  // namespace mindspore