1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/train_session.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include <queue>
26 #include <map>
#include <unordered_map>
27 #include <set>
28 #include "include/errorcode.h"
29 #include "src/litert/lite_model.h"
30 #include "src/litert/kernel_exec_util.h"
31 #include "src/tensor.h"
32 #include "src/litert/kernel_registry.h"
33 #include "src/common/prim_util.h"
34 #include "src/common/tensor_util.h"
35 #include "src/common/utils.h"
36 #include "src/train/optimizer_kernel.h"
37 #include "src/train/train_utils.h"
38 #include "src/train/train_export.h"
39 #include "src/train/opt_allocator.h"
40 #include "src/train/static_allocator.h"
41 #include "src/train/train_populate_parameter.h"
42 #include "src/train/train_populate_parameter_v0.h"
43
44 namespace mindspore {
45 namespace lite {
46 namespace {
47 void FreeGradients(const std::vector<lite::Tensor *> &gradients) {
48 for (auto &gradient : gradients) {
49 delete gradient;
50 }
51 }
52
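// For optimizer parameters that are not constant tensors (i.e. they are produced by another kernel),
// add the constant first input of the producing kernel to the trainable-params list instead.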
53 void AddNonConstTrainableParams(const std::vector<kernel::KernelExec *> &in_kernels, kernel::OptimizerKernel *optimizer,
54 std::vector<lite::Tensor *> *params) {
55 auto indices = optimizer->GetTrainableParamsIdxs();
56 if (params->size() == indices.size()) {
57 return;
58 }
59 for (size_t ix = 0; ix < indices.size(); ix++) {
60 auto param = optimizer->in_tensors().at(indices[ix]);
61 if (param->IsConst()) {
62 continue;
63 }
64 for (size_t i = 0; i < in_kernels.size(); i++) {
65 auto out_tensors = in_kernels.at(i)->out_tensors();
66 if (std::find(out_tensors.begin(), out_tensors.end(), param) != out_tensors.end() &&
67 !in_kernels.at(i)->in_tensors().empty()) {
68 auto filtered_tensor = in_kernels.at(i)->in_tensors().at(FIRST_INPUT);
69 if (filtered_tensor->IsConst()) {
70 params->emplace_back(filtered_tensor);
71 break;
72 }
73 }
74 }
75 }
76 }
77 } // namespace
78 const char *kGradName = "Gradients";
79 const char *kOptimizerName = "optimizer";
80 constexpr auto kObfNodeName = "obf_op-obf_mul";
81 constexpr size_t kFloatSize = 4;
82 constexpr int kDataIndex = 1;
83
84 TrainSession::TrainSession() {
85 is_train_session_ = true;
86 InitCallBack();
87 }
88
89 int TrainSession::TrainInit(const std::shared_ptr<InnerContext> &context, const TrainCfg *train_cfg) {
90 if (train_cfg != nullptr) {
91 if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
92 MS_LOG(ERROR) << "illegal loss scale configuration";
93 return RET_NULL_PTR;
94 }
95 cfg_ = *train_cfg;
96 }
97 if (context == nullptr) {
98 MS_LOG(ERROR) << "context cannot be nullptr";
99 return RET_NULL_PTR;
100 }
101 allocator_ = context->allocator;
102 return lite::LiteSession::Init(context);
103 }
104
105 std::vector<CreatorOp> TrainSession::ReplaceOps() {
106 const std::vector<CreatorOp> replace = {
107     // currently no ops are hijacked by TrainSession
108 };
109 mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
110 std::vector<CreatorOp> results;
111 for (auto v : replace) {
112 const CreatorOp cl = make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v)));
113 results.push_back(cl);
114 reg->RegKernel(std::get<0>(v), std::get<1>(v));
115 }
116 return results;
117 }
118
119 void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) {
120 mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
121 for (auto v : restore) {
122 reg->RegKernel(std::get<0>(v), std::get<1>(v));
123 }
124 }
125
126 int TrainSession::AllocWorkSpace() {
127 size_t workspace_size = 0;
128 for (auto kernel : this->train_kernels_) {
129 if (workspace_size < static_cast<kernel::LiteKernel *>(kernel->kernel())->workspace_size()) {
130 workspace_size = static_cast<kernel::LiteKernel *>(kernel->kernel())->workspace_size();
131 }
132 }
133 workspace_ = malloc(workspace_size);
134 if (workspace_ == nullptr) {
135 MS_LOG(ERROR) << "cannot allocate " << workspace_size << " for workspace";
136 return RET_ERROR;
137 }
138 for (auto kernel : this->train_kernels_) {
139 static_cast<kernel::LiteKernel *>(kernel->kernel())->set_workspace(workspace_);
140 }
141 return RET_OK;
142 }
143
144 void TrainSession::FreeWorkSpace() {
145 if (workspace_ != nullptr) {
146 free(workspace_);
147 workspace_ = nullptr;
148 }
149 for (auto kernel : this->train_kernels_) {
150 static_cast<kernel::LiteKernel *>(kernel->kernel())->FreeWorkspace();
151 }
152 }
153
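// Registers the scheduler's mixed-precision callback: when CPU fp16 is enabled, a node runs in fp16 unless it
// belongs to the optimizer, the loss function, or (per configuration) batch norm; Cast nodes reading fp16 data
// and nodes with fp16 const inputs are forced to fp16.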
154 int TrainSession::InitCallBack() {
155 sched_mix_precision_callback_ = [&](const LiteGraph::Node *node) {
156 if (!context_->IsCpuFloat16Enabled()) {
157 return false;
158 }
159 bool force_fp16 = false;
160 auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
161 if (node_type == schema::PrimitiveType_Cast) {
162 schema::Tensor *tensor = model_.get()->graph_.all_tensors_.at(node->input_indices_[0]);
163 if (tensor->dataType() == kNumberTypeFloat16) {
164 force_fp16 = true;
165 } else if (tensor->dataType() == kNumberTypeFloat32) {
166 return false;
167 }
168 } else {
169 auto in_size = node->input_indices_.size();
170 for (std::size_t k = 0; k < in_size; k++) {
171 schema::Tensor *tensor = model_->graph_.all_tensors_.at(node->input_indices_[k]);
172 if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
173 force_fp16 = true;
174 break;
175 }
176 }
177 }
178
179 const auto &node_name = node->name_;
180 bool is_fp16 = true;
181 if (!force_fp16) {
182 // optimizer runs in fp32
183 if (node_name.find(kOptimizerName) != std::string::npos) {
184 is_fp16 = false;
185 }
186 // loss function runs in fp32
187 auto v = get_loss_name();
188 for (auto &s : v) {
189 if (node_name.find(s) != std::string::npos) {
190 is_fp16 = false;
191 break;
192 }
193 }
194 // run bn according to user configuration
195 if ((cfg_.mix_precision_cfg_.keep_batchnorm_fp32_) &&
196 (node_type == schema::PrimitiveType_FusedBatchNorm || node_type == schema::PrimitiveType_BatchNorm ||
197 node_type == schema::PrimitiveType_BatchNormGrad)) {
198 is_fp16 = false;
199 }
200 }
201 MS_LOG(DEBUG) << "Debug: " << node_name << ((is_fp16) ? " fp16" : " fp32");
202 return is_fp16;
203 };
204 return RET_OK;
205 }
206
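// Statically plans output-tensor memory for the given kernels (only when the static allocator is in use):
// an output reuses an input's buffer in place when possible, otherwise it receives an offset from OptAllocator;
// all offsets are then backed by a single contiguous allocation.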
207 int TrainSession::AllocTensors(const std::vector<kernel::KernelExec *> &kernels) {
208 if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
209 OptAllocator allocator;
210 std::unordered_map<lite::Tensor *, int> ref_count;
211 std::unordered_map<lite::Tensor *, size_t> offset_map;
212 int counter = 0;
213 uint32_t input_idx = 0;
214 for (auto &kernel : kernels) {
215 for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
216 auto tensor = kernel->out_tensors().at(i);
217 bool in_place = false;
218 if (counter != 0) {
219 in_place = IsInPlaceTensor(kernel, i, ref_count, &input_idx);
220 }
221 counter++;
222 size_t offset;
223 if (in_place) {
224 offset = GetInplaceTensorOffset(kernel, offset_map, &ref_count, input_idx);
225 } else {
226 size_t size = tensor->Size();
227 offset = allocator.Malloc(size);
228 }
229 offset_map[tensor] = offset;
230 ref_count[tensor] = tensor->init_ref_count();
231 }
232 for (auto tensor : kernel->in_tensors()) {
233 if (tensor->category() == lite::Category::VAR) {
234 int count = ref_count[tensor] - 1;
235 ref_count[tensor] = count;
236 if (count == 0) {
237 allocator.Free(offset_map[tensor]);
238 }
239 }
240 }
241 }
242 // Set Tensor data
243 auto size = allocator.total_size();
244 if (size > tensors_data_size_) {
245 free(tensors_data_);
246 tensors_data_ = nullptr;
247 }
248 if (tensors_data_ == nullptr) {
249 auto buf = malloc(size);
250 if (buf == nullptr) {
251       MS_LOG(ERROR) << "cannot allocate buffer size " << size;
252 return RET_ERROR;
253 }
254 StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
255 alloc->SetContex(buf, size);
256 tensors_data_ = buf;
257 tensors_data_size_ = size;
258 }
259 for (auto kernel : train_kernels_) {
260 for (auto tensor : kernel->out_tensors()) {
261 auto it = offset_map.find(tensor);
262 if (it != offset_map.end()) {
263 tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensors_data_) + it->second));
264 }
265 }
266 }
267 return RET_OK;
268 }
269
270 int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
271
272 int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) {
273 model_ = model;
274 auto restore = ReplaceOps();
275 sched_cb_ = std::make_unique<SchedulerCb>(sched_mix_precision_callback_);
276 if (sched_cb_ == nullptr) {
277 MS_LOG(ERROR) << "Failed to create SchedulerCb node";
278 return RET_ERROR;
279 }
280
281 if (reinterpret_cast<LiteModel *>(model_.get())->GetSchemaVersion() == SCHEMA_VERSION::SCHEMA_CUR) {
282 kernel::PopulateTrainParameters();
283 }
284
285 auto ret = lite::LiteSession::CompileGraph(model_.get());
286 if (ret != RET_OK) {
287 MS_LOG(ERROR) << "failed to compile train model";
288 return RET_ERROR;
289 }
290 orig_output_node_map_ = output_node_map_;
291 orig_output_tensor_map_ = output_tensor_map_;
292 orig_output_tensor_names_ = output_tensor_names_;
293 for (auto inTensor : inputs_) inTensor->MutableData();
294 RestoreOps(restore);
295 CompileTrainKernels(); // Prepare a list of train kernels
296 CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step)
297 CompileTrainableParams(); // Prepare trainable parameters of optimizers
298 CompileTrainOutputs(); // prepare outputs in train mode
299 CompileEvalOutputs(); // prepare outputs in eval mode
300 // Prepare a list of eval kernels
301 if (CompileInferenceKernels() != RET_OK) {
302 MS_LOG(WARNING) << "CompileInferenceKernels failed.";
303 return RET_ERROR;
304 }
305 ret = AllocWorkSpace();
306 if (ret != RET_OK) {
307 MS_LOG(ERROR) << "failed to allocate space";
308 return RET_ERROR;
309 }
310 ret = AllocTensors(train_kernels_);
311 if (ret != RET_OK) {
312 MS_LOG(ERROR) << "failed to allocate space";
313 return RET_ERROR;
314 }
315 // Prepare a list of kernels which are const folded
316 MS_CHECK_TRUE_MSG(CompileConstFoldedKernels() == RET_OK, RET_ERROR, "CompileConstFoldedKernels failed.");
317 return RET_OK;
318 }
319
320 TrainSession::~TrainSession() {
321 FreeWorkSpace();
322 if (tensors_data_ != nullptr) {
323 free(tensors_data_);
324 tensors_data_ = nullptr;
325 }
326 }
327
328 int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
329 const std::vector<kernel::KernelExec *> &run_kernels) {
330 for (auto *kernel : run_kernels) {
331 MS_ASSERT(kernel != nullptr);
332 auto ret = kernel->Execute(before, after);
333 if (RET_OK != ret) {
334 MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
335 return ret;
336 }
337 }
338 return RET_OK;
339 }
340
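// Moves the data (and data type) saved in the temporary tensors created by CastTensor back into the original
// tensors once a mixed-precision kernel has finished running.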
341 void TrainSession::RestoreTensorData() {
342 for (auto &restored_origin_tensor : restored_origin_tensors_) {
343 auto *origin_tensor = restored_origin_tensor.first;
344 auto *restored_tensor = restored_origin_tensor.second;
345 MS_ASSERT(origin_tensor != nullptr);
346 MS_ASSERT(restored_tensor != nullptr);
347
348 bool own_data = restored_tensor->own_data();
349 if (origin_tensor->data() == nullptr) {
350 restored_tensor->FreeData();
351 } else {
352 origin_tensor->FreeData();
353 }
354 origin_tensor->set_data_type(restored_tensor->data_type());
355 origin_tensor->set_data(restored_tensor->data());
356 origin_tensor->set_own_data(own_data);
357 }
358 }
359
360 void TrainSession::FreeRestoreTensors() {
361 for (auto &restored_origin_tensor : restored_origin_tensors_) {
362 auto *restored_tensor = restored_origin_tensor.second;
363 restored_tensor->set_data(nullptr);
364 delete (restored_tensor);
365 }
366 restored_origin_tensors_.clear();
367 }
368
369 bool TrainSession::IsLossTensor(Tensor *tensor) {
370 MS_ASSERT(tensor != nullptr);
371 bool isLoss = false;
372 auto t_n = tensor->tensor_name();
373 auto v = get_loss_name();
374 for (auto &s : v) {
375 if (t_n.find(s) != std::string::npos) {
376 isLoss = true;
377 break;
378 }
379 }
380 return isLoss;
381 }
382
383 bool TrainSession::AllInputsNeedScale(kernel::KernelExec *kernel) {
384 auto type = kernel->type();
385 bool is_scale = false;
386 switch (type) {
387 case schema::PrimitiveType_AbsGrad:
388 case schema::PrimitiveType_AddFusion:
389 case schema::PrimitiveType_SubFusion:
390 case schema::PrimitiveType_AddN:
391 for (auto &tensor : kernel->in_tensors()) {
392 is_scale = is_scale || tensor->IsScale();
393 }
394 return (is_scale);
395 default:
396 return false;
397 }
398 return false;
399 }
400
401 int TrainSession::MixPrecisionPreProcess(kernel::KernelExec *kernel, float scale) {
402 auto kernel_type = kernel->desc().data_type;
403 auto all_scale = AllInputsNeedScale(kernel);
404
405 for (auto &tensor : kernel->in_tensors()) {
406 if ((tensor->IsScale() == false) && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || (all_scale == true))) {
407 ScaleTensor(tensor, scale);
408 }
409 // adjust tensor data type
410 if (tensor->data_type() != kernel_type) {
411 auto restore_tensor = CastTensor(tensor, kernel_type, this->context_->device_and_pkg_support_fp16_);
412 if (restore_tensor != nullptr) {
413 restored_origin_tensors_[tensor] = restore_tensor;
414 }
415 }
416 }
417 return RET_OK;
418 }
419
420 int TrainSession::MixPrecisionPostProcess(kernel::KernelExec *kernel) {
421 RestoreTensorData();
422 FreeRestoreTensors();
423
424 float scale = 1.0f;
425 auto all_scale = AllInputsNeedScale(kernel);
426 for (auto &tensor : kernel->in_tensors()) {
427 if (tensor->IsScale()) {
428 scale *= tensor->get_scale();
429 if (all_scale) {
430 break;
431 }
432 }
433 }
434 for (auto &tensor : kernel->out_tensors()) {
435 tensor->set_scale(scale);
436 }
437
438 for (auto &tensor : kernel->in_tensors()) {
439 if (tensor->IsScale() && ((!IsLossKernel(kernel) && IsLossTensor(tensor)) || all_scale)) {
440 ScaleTensor(tensor, 1.0f / scale);
441 }
442 }
443 return RET_OK;
444 }
445
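// Executes kernels with loss-scale based mixed precision: inputs are scaled and cast per kernel, the loss scale
// is halved on overflow (RET_OUT_OF_TENSOR_RANGE) when dynamic loss scaling is enabled and doubled again after
// enough clean iterations, and eval-mode outputs are cast back to fp32 at the end.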
446 int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
447 const std::vector<kernel::KernelExec *> &run_kernels) {
448 float scale = cfg_.mix_precision_cfg_.loss_scale_;
449 for (auto *kernel : run_kernels) {
450 MS_ASSERT(kernel != nullptr);
451 auto ret = MixPrecisionPreProcess(kernel, scale);
452 if (ret != RET_OK) {
453 MS_LOG(ERROR) << "MixPrecisionPreProcess failed.";
454 return RET_ERROR;
455 }
456 ret = kernel->Execute(before, after);
457 if (RET_OK != ret) {
458 MixPrecisionPostProcess(kernel);
459 // decrease loss scale in case of nan or inf
460 if (ret == RET_OUT_OF_TENSOR_RANGE) {
461 bool is_dynamic_scale = cfg_.mix_precision_cfg_.dynamic_loss_scale_;
462 cfg_.mix_precision_cfg_.loss_scale_ = std::max(((is_dynamic_scale) ? (scale / 2.f) : scale), 1.0f);
463 num_of_not_nan_iter_ = 0;
464 return RET_OK;
465 }
466 MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
467 return ret;
468 }
469 MixPrecisionPostProcess(kernel);
470 }
471   // increase the dynamic loss scale once the not-NaN iteration threshold is passed
472 if (cfg_.mix_precision_cfg_.dynamic_loss_scale_) {
473 num_of_not_nan_iter_++;
474 if (num_of_not_nan_iter_ >= cfg_.mix_precision_cfg_.num_of_not_nan_iter_th_) {
475 cfg_.mix_precision_cfg_.loss_scale_ = std::min((cfg_.mix_precision_cfg_.loss_scale_ * 2.0f), 65536.0f);
476 num_of_not_nan_iter_ = 0;
477 }
478 }
479
480 // cast output to FP32
481 if (train_mode_ == false) {
482 for (auto t : this->outputs_) {
483 if (t->data_type() == kNumberTypeFloat16) {
484 auto restore = CastTensor(t, kNumberTypeFloat32, this->context_->device_and_pkg_support_fp16_);
485 delete restore;
486 }
487 }
488 }
489 return RET_OK;
490 }
491
492 int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
493 // check inputs
494 auto ret = CheckTensorsInvalid(inputs_);
495 if (ret != RET_OK) {
496 MS_LOG(ERROR) << "CheckInputs failed";
497 return ret;
498 }
499
500 // build out tensor
501 this->outputs_.clear();
502 for (auto &ms_tensors : output_node_map_) {
503 for (auto &ms_tensor : ms_tensors.second) {
504 auto lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
505 this->outputs_.push_back(lite_tensor);
506 }
507 }
508
509 if (this->context_ == nullptr) {
510 MS_LOG(ERROR) << "context is null";
511 return lite::RET_NULL_PTR;
512 }
513 auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
514 if (context_->IsCpuFloat16Enabled()) {
515 ret = MixPrecisionExecKernels(before, after, run_kernels);
516 } else {
517 ret = ExecKernels(before, after, run_kernels);
518 }
519 if (ret != RET_OK) {
520 MS_LOG(ERROR) << "failed to run model kernels";
521 return ret;
522 }
523
524 if (train_mode_ && (virtual_batch_multiplier_ != 0)) {
525 virtual_batch_idx_++;
526 if (virtual_batch_idx_ >= virtual_batch_multiplier_) {
527 virtual_batch_idx_ = 0;
528 ret = OptimizerStep();
529 if (ret != RET_OK) {
530 MS_LOG(ERROR) << "failed to optimize model weights";
531 return ret;
532 }
533 }
534 }
535 return RET_OK;
536 }
537
538 int TrainSession::Train() {
539 // shift kernels to train mode
540 train_mode_ = true;
541 virtual_batch_idx_ = 0;
542 for (auto &kernel : this->train_kernels_) {
543 MS_ASSERT(kernel != nullptr);
544 auto ret = kernel->Train();
545 if (ret != RET_OK) {
546 MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
547 return RET_ERROR;
548 }
549 }
550 // set train outputs
551 output_node_map_ = train_output_node_map_;
552 output_tensor_map_ = train_output_tensor_map_;
553 output_tensor_names_ = train_output_tensor_names_;
554 kernel::KernelExecUtil::InitTensorInitRefCount(train_kernels_);
555   for (auto &ms_tensors : eval_output_node_map_) {  // Also allow looking at predictions during training
556 for (auto &ms_tensor : ms_tensors.second) {
557 lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
558 lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
559 }
560 }
561 // allocate tensors
562 auto ret = AllocTensors(train_kernels_);
563 if (ret != RET_OK) {
564 MS_LOG(ERROR) << "failed to allocate tensor space";
565 return RET_ERROR;
566 }
567 return RET_OK;
568 }
569
570 int TrainSession::Eval() {
571 // shift kernels to eval mode
572 train_mode_ = false;
573 virtual_batch_idx_ = 0;
574 for (auto &kernel : this->train_kernels_) {
575 MS_ASSERT(kernel != nullptr);
576 auto ret = kernel->Eval();
577 if (ret != RET_OK) {
578 MS_LOG(ERROR) << kernel->name() << " failed to set eval mode";
579 return RET_ERROR;
580 }
581 }
582 // set eval outputs
583 output_node_map_ = eval_output_node_map_;
584 output_tensor_map_ = eval_output_tensor_map_;
585 output_tensor_names_ = eval_output_tensor_names_;
586 kernel::KernelExecUtil::InitTensorInitRefCount(inference_kernels_);
587 for (auto &ms_tensors : eval_output_node_map_) {
588 for (auto &ms_tensor : ms_tensors.second) {
589 lite::Tensor *lite_tensor = static_cast<lite::Tensor *>(ms_tensor);
590 lite_tensor->set_init_ref_count(lite_tensor->init_ref_count() + 1);
591 }
592 }
593 auto ret = AllocTensors(inference_kernels_);
594 if (ret != RET_OK) {
595 MS_LOG(ERROR) << "failed to allocate space";
596 return RET_ERROR;
597 }
598 return RET_OK;
599 }
600
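// Builds the eval-mode output maps: for every loss kernel (that is not a gradient kernel), the tensors produced
// by its non-loss input kernels become the network outputs; falls back to the original model outputs if none are found.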
601 void TrainSession::CompileEvalOutputs() {
602 eval_output_node_map_.clear();
603 eval_output_tensor_map_.clear();
604 eval_output_tensor_names_.clear();
605 for (auto kernel : this->train_kernels_) {
606 if (!IsLossKernel(kernel) || IsGradKernel(kernel)) {
607 continue;
608 }
609 // if LossKernel and not GradKernel, deal with outputs
610 for (auto in_kernel : kernel->in_kernels()) {
611 bool is_loss = IsLossInKernel(in_kernel);
612 if (is_loss) {
613 continue;
614 }
615 // insert if not already in
616 auto out_tensors = TSFindTensors(in_kernel, kernel);
617 if (eval_output_node_map_.find(in_kernel->name()) != eval_output_node_map_.end()) {
618 auto exist_out_tensors = eval_output_node_map_[in_kernel->name()];
619 auto kernel_all_out_tensors = in_kernel->out_tensors();
620 eval_output_node_map_[in_kernel->name()] = {};
621 for (auto tensor : kernel_all_out_tensors) {
622 if (std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end() ||
623 std::find(exist_out_tensors.begin(), exist_out_tensors.end(), tensor) != exist_out_tensors.end()) {
624 eval_output_node_map_[in_kernel->name()].emplace_back(tensor);
625 }
626 }
627 } else {
628 eval_output_node_map_[in_kernel->name()] = out_tensors;
629 }
630 for (auto out_tensor : out_tensors) {
631 auto index = TSFindTensor(tensors_, out_tensor);
632 if (std::find(eval_output_tensor_names_.begin(), eval_output_tensor_names_.end(), out_tensor->tensor_name()) !=
633 eval_output_tensor_names_.end() ||
634 std::find(eval_output_tensor_names_.begin(), eval_output_tensor_names_.end(), std::to_string(index)) !=
635 eval_output_tensor_names_.end()) {
636 continue;
637 }
638 if (index != tensors_.size()) {
639 if (!out_tensor->tensor_name().empty()) {
640 eval_output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor));
641 eval_output_tensor_names_.emplace_back(out_tensor->tensor_name());
642 } else {
643 eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), out_tensor));
644 eval_output_tensor_names_.emplace_back(std::to_string(index));
645 }
646 }
647 }
648 }
649 }
650 if (eval_output_node_map_.empty()) eval_output_node_map_ = orig_output_node_map_;
651 if (eval_output_tensor_map_.empty()) eval_output_tensor_map_ = orig_output_tensor_map_;
652 if (eval_output_tensor_names_.empty()) eval_output_tensor_names_ = orig_output_tensor_names_;
653 }
654
655 void TrainSession::CompileTrainOutputs() {
656 train_output_node_map_.clear();
657 train_output_tensor_map_.clear();
658 train_output_tensor_names_.clear();
659 for (auto kernel : this->train_kernels_) {
660 if (orig_output_node_map_.find(kernel->name()) == orig_output_node_map_.end()) continue;
661 // Mask out optimizer out tensors
662 if (IsMaskOutput(kernel)) continue;
663 // insert if not already in
664 if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) {
665 auto *ms_tensor = kernel->out_tensors().at(0);
666 if (ms_tensor != nullptr) {
667 train_output_node_map_[kernel->name()].emplace_back(ms_tensor);
668 auto index = TSFindTensor(tensors_, ms_tensor);
669 if (index != tensors_.size()) {
670 if (!ms_tensor->tensor_name().empty()) {
671 train_output_tensor_map_.insert(std::make_pair(ms_tensor->tensor_name(), ms_tensor));
672 train_output_tensor_names_.emplace_back(ms_tensor->tensor_name());
673 } else {
674 train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor));
675 train_output_tensor_names_.emplace_back(std::to_string(index));
676 }
677 }
678 }
679 }
680 }
681 if (train_output_node_map_.empty()) train_output_node_map_ = orig_output_node_map_;
682 if (train_output_tensor_map_.empty()) train_output_tensor_map_ = orig_output_tensor_map_;
683 if (train_output_tensor_names_.empty()) train_output_tensor_names_ = orig_output_tensor_names_;
684 }
685
686 void TrainSession::BuildInferenceKernelsRecursive(kernel::KernelExec *kernel, std::vector<kernel::KernelExec *> *v) {
687 MS_ASSERT(kernel != nullptr);
688 MS_ASSERT(v != nullptr);
689 if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not already in vector
690 for (auto in_node : kernel->in_kernels()) {
691 BuildInferenceKernelsRecursive(in_node, v);
692 }
693 if (!IsLossKernel(kernel)) v->push_back(kernel);
694 }
695 }
696
697 void TrainSession::CompileTrainKernels() {
698 train_kernels_.clear();
699 std::vector<kernel::KernelExec *> train_kernels;
700 for (auto ori_kernel : kernels_) {
701 if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
702 train_kernels.push_back(ori_kernel);
703 } else {
704 auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
705 std::copy(sub_graph->nodes().begin(), sub_graph->nodes().end(), std::back_inserter(train_kernels));
706 }
707 }
708   // For LSTM, GPU operators are synchronized internally, hence we need to add a sync mechanism to the execution graph
709 for (auto k : train_kernels) {
710 if (k->type() == schema::PrimitiveType_LSTMGradWeight) {
711 // Find PrimitiveType_LSTMGradData that matches this PrimitiveType_LSTMGradWeight
712 for (auto mk : train_kernels) {
713 if (mk->type() == schema::PrimitiveType_LSTMGradData) {
714 if (k->in_tensors().at(C2NUM)->tensor_name() == mk->in_tensors().at(0)->tensor_name()) {
715 mk->AddOutKernel(k);
716 k->AddInKernel(mk);
717 }
718 }
719 }
720 }
721 }
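  // Topologically sort the kernels into train_kernels_ (Kahn's algorithm): seed the queue with kernels that have
  // no input kernels, then enqueue each out-kernel once all of its input kernels have been emitted.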
722 std::queue<kernel::KernelExec *> queue;
723 for (auto k : train_kernels) {
724 auto in_kernels = k->in_kernels();
725 if (in_kernels.size() == 0) {
726 queue.push(k);
727 }
728 }
729 std::unordered_map<kernel::KernelExec *, int> map;
730 while (!queue.empty()) {
731 // pop first element
732 auto k = queue.front();
733 train_kernels_.push_back(k);
734 queue.pop();
735 for (auto &ok : k->out_kernels()) {
736 auto cnt_iter = map.find(ok);
737 if (cnt_iter == map.end()) {
738 int ref_cnt = ok->in_kernels().size();
739 map[ok] = ref_cnt;
740 }
741 cnt_iter = map.find(ok);
742 auto ref_cnt = cnt_iter->second - 1;
743 map[ok] = ref_cnt;
744 if (ref_cnt <= 0) {
745 queue.push(cnt_iter->first);
746 }
747 }
748 }
749 }
750
751 int TrainSession::CompileInferenceKernels() {
752 inference_kernels_.clear();
753 for (auto item : eval_output_node_map_) {
754 std::string kernel_name = item.first;
755 auto kernel = TSFindKernel(train_kernels_, kernel_name);
756 if (kernel == nullptr) {
757 MS_LOG(ERROR) << "kernel is nullptr";
758 return RET_ERROR;
759 }
760 BuildInferenceKernelsRecursive(kernel, &inference_kernels_);
761 }
762
763 if (train_kernels_.size() == inference_kernels_.size()) {
764     MS_LOG(WARNING) << "This is an inference model; returning error in TrainSession.";
765 return RET_ERROR;
766 }
767
768 if (inference_kernels_.size() == 0) {
769 inference_kernels_ = this->train_kernels_;
770 }
771 return RET_OK;
772 }
773
774 void TrainSession::CompileOptimizedKernels() {
775 std::vector<lite::Tensor *> out_tensor;
776 for (auto kernel : this->train_kernels_) {
777 if (IsOptimizer(kernel)) {
778 std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
779 if (cfg_.accumulate_gradients_) {
780 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
781 auto ret = optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
782 if (ret != RET_OK) {
783 MS_LOG(ERROR) << "SetOptimizerMode failed.";
784 return;
785 }
786 }
787 }
788 }
789
790 for (auto kernel : this->train_kernels_) {
791 if (!IsOptimizer(kernel)) {
792 for (auto it : kernel->in_tensors()) {
793 if (std::find(out_tensor.begin(), out_tensor.end(), it) != out_tensor.end()) {
794 kernel->SetTrainable(true);
795 break;
796 }
797 }
798 }
799 }
800 }
801
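// Pre-executes kernels whose inputs are all constant (and not graph inputs), marking their outputs as const
// tensors; kernels that cannot be folded are collected in const_fold_kernels_ for inference export.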
802 int TrainSession::CompileConstFoldedKernels() {
803 const_output_tensors_.clear();
804 for (auto kernel : this->inference_kernels_) {
805 bool is_input_const = true;
806 for (auto input : kernel->in_tensors()) {
807 if ((!input->IsConst() || input->IsGraphInput()) &&
808 std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) {
809 is_input_const = false;
810 }
811 if (!is_input_const) {
812 const_fold_kernels_.emplace_back(kernel);
813 break;
814 }
815 }
816 if (is_input_const) {
817 auto ret = kernel->Execute();
818 if (RET_OK != ret) {
819 MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
820 return ret;
821 }
822 for (auto output : kernel->out_tensors()) {
823 const_output_tensors_.emplace_back(output);
824 output->set_category(Category::CONST_TENSOR);
825 }
826 }
827 }
828 return RET_OK;
829 }
830
831 int TrainSession::FindConstFoldedKernels() {
832 float obf_ratio = ModelRecoverObfuscate(false);
833 if (!FloatCompare(obf_ratio, 1.0)) {
834     MS_LOG(INFO) << "obfuscated model does not need const folding.";
835 const_fold_kernels_ = this->inference_kernels_;
836 const_output_tensors_ = {};
837 return RET_OK;
838 }
839 const_fold_kernels_.clear();
840 const_output_tensors_.clear();
841 for (auto kernel : this->inference_kernels_) {
842 bool is_input_const = true;
843 for (auto input : kernel->in_tensors()) {
844 if ((!input->IsConst() || input->IsGraphInput()) &&
845 std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) {
846 is_input_const = false;
847 }
848 if (!is_input_const) {
849 const_fold_kernels_.emplace_back(kernel);
850 break;
851 }
852 }
853 if (is_input_const) {
854 auto ret = kernel->Execute();
855 if (RET_OK != ret) {
856 MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
857 return ret;
858 }
859 for (auto output : kernel->out_tensors()) {
860 output->set_category(Category::CONST_TENSOR);
861 const_output_tensors_.emplace_back(output);
862 }
863 }
864 }
865 return RET_OK;
866 }
867
868 void TrainSession::CompileTrainableParams() {
869 for (auto kernel : this->train_kernels_) {
870 if (!IsOptimizer(kernel)) {
871 continue;
872 }
873 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
874 auto params = optimizer->GetTrainableParams();
875 auto in_kernels = kernel->in_kernels();
876   AddNonConstTrainableParams(in_kernels, optimizer, &params);
877
878 for (auto param : params) {
879 if (std::find(trainable_parameters_.begin(), trainable_parameters_.end(), param) != trainable_parameters_.end()) {
880 continue;
881 }
882 trainable_parameters_.emplace_back(param);
883 }
884 }
885 }
886
887 int TrainSession::SetLearningRate(float learning_rate) {
888 if (learning_rate < 0.0f) {
889     MS_LOG(ERROR) << "learning rate should not be less than 0";
890 return RET_ERROR;
891 }
892 for (auto kernel : this->train_kernels_) {
893 if (IsOptimizer(kernel)) {
894 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
895 auto ret = optimizer->SetLearningRate(learning_rate);
896 if (ret != RET_OK) {
897 MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
898 return RET_ERROR;
899 }
900 }
901 }
902 return RET_OK;
903 }
904
905 float TrainSession::GetLearningRate() {
906 for (auto kernel : this->train_kernels_) {
907 if (IsOptimizer(kernel)) {
908 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
909 return optimizer->GetLearningRate();
910 }
911 }
912 return 0.0;
913 }
914
915 std::vector<lite::Tensor *> TrainSession::GetOptimizerParams() const {
916 std::vector<lite::Tensor *> params;
917 for (auto kernel : this->train_kernels_) {
918 if (IsOptimizer(kernel)) {
919 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
920 auto kernelParams = optimizer->GetOptimizerParams();
921 for (size_t ix = 0; ix < kernelParams.size(); ix++) {
922 auto kernelParam = kernelParams[ix];
923 auto name = kernelParam->tensor_name();
924 bool found = false;
925 for (size_t iy = 0; iy < params.size(); iy++) {
926 if (params[iy]->tensor_name() == name) {
927 found = true;
928 break;
929 }
930 }
931 if (!found) {
932 params.push_back(kernelParam);
933 }
934 }
935 }
936 }
937 return params;
938 }
939
940 int TrainSession::SetOptimizerParams(const std::vector<lite::Tensor *> &params) {
941 for (size_t ix = 0; ix < params.size(); ix++) {
942 auto param = params[ix];
943 if (param == nullptr) {
944 MS_LOG(ERROR) << "Param tensor is null.";
945 return RET_ERROR;
946 }
947 bool found = false;
948 for (auto kernel : this->train_kernels_) {
949 if (IsOptimizer(kernel)) {
950 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
951 found = optimizer->SetOptimizerParams(param);
952 if (found) break;
953 }
954 }
955 if (!found) {
956       MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elements and type "
957 << param->data_type() << " is not a valid params tensor";
958 return RET_ERROR;
959 }
960 }
961 return RET_OK;
962 }
963
964 std::vector<lite::Tensor *> TrainSession::GetGradients() const {
965 std::vector<lite::Tensor *> gradients;
966 for (auto kernel : this->train_kernels_) {
967 if (IsOptimizer(kernel)) {
968 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
969 auto kernelGradint = optimizer->GetGradients();
970 if (kernelGradint != nullptr) {
971 gradients.push_back(kernelGradint);
972 }
973 }
974 }
975 return gradients;
976 }
977
978 int TrainSession::ApplyGradients(const std::vector<lite::Tensor *> &gradients) {
979 auto current_gradients = GetGradients();
980 if (current_gradients.size() != gradients.size()) {
981 MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
982 << current_gradients.size();
983 FreeGradients(current_gradients);
984 return RET_ERROR;
985 }
986 for (size_t ix = 0; ix < gradients.size(); ix++) {
987 auto gradient = gradients[ix];
988 if (gradient == nullptr) {
989 MS_LOG(ERROR) << "gradient tensor is null.";
990 FreeGradients(current_gradients);
991 return RET_ERROR;
992 }
993 bool found = false;
994 for (size_t iy = 0; iy < current_gradients.size(); iy++) {
995 auto current_gradient = current_gradients[iy];
996 if (current_gradient->tensor_name() == gradient->tensor_name()) {
997 found = true;
998 if (current_gradient->Size() == gradient->Size()) {
999 std::copy(static_cast<uint8_t *>(gradient->data()),
1000 static_cast<uint8_t *>(gradient->data()) + gradient->Size(),
1001 static_cast<uint8_t *>(current_gradient->MutableData()));
1002 } else {
1003 MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
1004 << " instead of " << current_gradient->Size();
1005 FreeGradients(current_gradients);
1006 return RET_ERROR;
1007 }
1008 break;
1009 }
1010 }
1011 if (!found) {
1012 MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
1013 FreeGradients(current_gradients);
1014 return RET_ERROR;
1015 }
1016 }
1017 FreeGradients(current_gradients);
1018 for (auto kernel : this->train_kernels_) {
1019 if (IsOptimizer(kernel)) {
1020 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
1021 if (optimizer->set_grad_sum_valid() != RET_OK) {
1022 MS_LOG(ERROR) << "set grad sum valid failed.";
1023 return RET_ERROR;
1024 }
1025 auto ret = optimizer->OptimizerStep();
1026 if (ret != RET_OK) {
1027 MS_LOG(ERROR) << "failed to optimize model weights";
1028 return ret;
1029 }
1030 }
1031 }
1032 return RET_OK;
1033 }
1034
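// Switches optimizers between NORMAL and VIRTUAL_BATCH weight-update modes; in virtual-batch mode the learning
// rate defaults to the current rate divided by the multiplier, and trainable batch-norm kernels get an adjusted momentum.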
1035 int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
1036 auto mod =
1037 (virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
1038 virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
1039 virtual_batch_idx_ = 0;
1040
1041 for (auto kernel : this->train_kernels_) {
1042 if (IsOptimizer(kernel)) {
1043 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
1044 if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
1045 optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
1046 MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
1047 return RET_ERROR;
1048 }
1049 auto ret = optimizer->SetOptimizerMode(mod);
1050 if (ret != RET_OK) {
1051 MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
1052 return RET_ERROR;
1053 }
1054 if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
1055 lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
1056 ret = optimizer->SetLearningRate(lr);
1057 } else {
1058 ret = optimizer->RestoreDefaultLearningRate();
1059 }
1060 if (ret != RET_OK) {
1061 MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
1062 return RET_ERROR;
1063 }
1064 }
1065
1066 if (IsBN(kernel) && kernel->IsTrainable()) {
1067 auto batchnorm = static_cast<kernel::LiteKernel *>(kernel->kernel());
1068 auto ret = batchnorm->SetupVirtualBatch(virtual_batch_multiplier_, momentum);
1069 if (ret != RET_OK) {
1070 MS_LOG(ERROR) << kernel->name() << " failed to set momentum";
1071 return RET_ERROR;
1072 }
1073 }
1074 }
1075 return RET_OK;
1076 }
1077 int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
1078 int tmp = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
1079 if (tmp != 0 && virtual_batch_multiplier_ != 0) {
1080 AdminSetupVirtualBatch(0, lr, momentum);
1081 }
1082 return AdminSetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
1083 }
1084
1085 int TrainSession::OptimizerStep() {
1086 for (auto kernel : this->train_kernels_) {
1087 if (IsOptimizer(kernel)) {
1088 auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
1089 auto ret = optimizer->OptimizerStep();
1090 if (ret != RET_OK) {
1091 MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
1092 return RET_ERROR;
1093 }
1094 }
1095 }
1096 return RET_OK;
1097 }
1098
1099 bool TrainSession::IsLossInKernel(const kernel::KernelExec *kernel) const {
1100 if (IsLossKernel(kernel) || IsGradKernel(kernel)) {
1101 return true;
1102 }
1103 for (auto in_kernel : kernel->in_kernels()) {
1104 if (IsLossKernel(in_kernel)) {
1105 return true;
1106 }
1107 }
1108 return false;
1109 }
1110
1111 bool TrainSession::IsLossKernel(const kernel::KernelExec *kernel) const {
1112 bool isLoss = false;
1113 for (auto &s : cfg_.loss_name_) {
1114 if (kernel->name().find(s) != std::string::npos) {
1115 isLoss = true;
1116 break;
1117 }
1118 }
1119 return isLoss;
1120 }
1121
1122 bool TrainSession::IsGradKernel(const kernel::KernelExec *kernel) const {
1123 return kernel->name().find(kGradName) != std::string::npos;
1124 }
1125
1126 bool TrainSession::IsOptimizer(kernel::KernelExec *kernel) const {
1127 return ((kernel->type() == schema::PrimitiveType_Adam) || (kernel->type() == schema::PrimitiveType_SGD) ||
1128 (kernel->type() == schema::PrimitiveType_ApplyMomentum));
1129 }
1130
1131 bool TrainSession::IsMaskOutput(kernel::KernelExec *kernel) const {
1132 return (IsOptimizer(kernel) || (kernel->type() == schema::PrimitiveType_Assign));
1133 }
1134
1135 bool TrainSession::IsBN(kernel::KernelExec *kernel) const {
1136 return ((kernel->type() == schema::PrimitiveType_BatchNorm) ||
1137 (kernel->type() == schema::PrimitiveType_FusedBatchNorm));
1138 }
1139
1140 int TrainSession::Resize(const std::vector<lite::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) {
1141 FreeWorkSpace();
1142 if (tensors_data_ != nullptr) {
1143 free(tensors_data_);
1144 tensors_data_ = nullptr;
1145 }
1146 auto ret = lite::LiteSession::Resize(inputs, dims);
1147 if (ret != RET_OK) {
1148 MS_LOG(ERROR) << "train resize input failed.";
1149 return RET_ERROR;
1150 }
1151 ret = AllocWorkSpace();
1152 if (ret != RET_OK) {
1153 MS_LOG(ERROR) << "failed to allocate space";
1154 return RET_ERROR;
1155 }
1156 ret = AllocTensors(train_kernels_);
1157 if (ret != RET_OK) {
1158 MS_LOG(ERROR) << "train alloc failed after resize.";
1159 return RET_ERROR;
1160 }
1161 return RET_OK;
1162 }
1163
1164 int TrainSession::FindUseInTensorKernel(std::vector<kernel::KernelExec *> *use_in_tensor_kernels,
1165 const std::vector<lite::Tensor *> &kernel_in_tensors,
1166 const std::vector<kernel::KernelExec *> &inference_kernels) {
1167 for (size_t i = 0; i < inference_kernels.size(); i++) {
1168 for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
1169 if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
1170 use_in_tensor_kernels->push_back(inference_kernels[i]);
1171 }
1172 }
1173 }
1174 return RET_OK;
1175 }
1176
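// Collects the transitive set of kernels required to produce the requested output tensors by walking backwards
// (BFS) through the kernels that produce each kernel's inputs.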
1177 int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_kernels,
1178 const std::vector<std::string> &export_output_tensor_names,
1179 const std::vector<kernel::KernelExec *> &inference_kernels) {
1180 std::vector<std::string> all_kernel_name = {};
1181 (void)std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
1182 [](kernel::KernelExec *kernel) { return kernel->name(); });
1183 std::queue<std::string> need_kernel_names;
1184 // Find the kernel name according to the tensor name
1185 for (auto &kernel : inference_kernels) {
1186 if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
1187 return IsContain(export_output_tensor_names, out_tensor->tensor_name());
1188 })) {
1189 need_kernel_names.push(kernel->name());
1190 }
1191 }
1192 if (need_kernel_names.size() == 0) {
1193 MS_LOG(ERROR) << "can not find tensor";
1194 return RET_ERROR;
1195 }
1196   // find all required kernels
1197 while (!need_kernel_names.empty()) {
1198 auto kernel_name = need_kernel_names.front();
1199 need_kernel_names.pop();
1200 auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
1201 if (it == all_kernel_name.end()) {
1202       MS_LOG(ERROR) << "cannot find kernel name when exporting the trained model.";
1203 return RET_ERROR;
1204 }
1205 auto kernel = inference_kernels[it - all_kernel_name.begin()];
1206 if (!IsContain(*export_kernels, kernel)) {
1207 export_kernels->push_back(kernel);
1208 }
1209 auto kernel_in_tensors = kernel->in_tensors();
1210 std::vector<kernel::KernelExec *> use_in_tensor_kernels;
1211 auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
1212 if (status != RET_OK) {
1213 MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
1214 return RET_ERROR;
1215 }
1216 for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
1217 need_kernel_names.push(use_in_tensor_kernels[i]->name());
1218 }
1219 }
1220 return RET_OK;
1221 }
1222
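// Shared export path for file and in-memory destinations: validates the destination and parameters, optionally
// prunes the graph to the requested output tensors, applies fusion and drops training-only nodes for inference
// export, and finally saves through TrainExport.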
1223 template <typename DestType>
1224 int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
1225 FormatType format, std::vector<std::string> out_put_tensor_name) {
1226 if constexpr (std::is_same_v<DestType, const std::string &>) {
1227 MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
1228 struct stat path_type;
1229 if (stat(destination.c_str(), &path_type) == RET_OK) {
1230 if (path_type.st_mode & S_IFDIR) {
1231         MS_LOG(ERROR) << "Destination must be a file path, but a directory was given";
1232 return RET_ERROR;
1233 }
1234 }
1235 } else if constexpr (std::is_same_v<DestType, Buffer *>) {
1236 MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
1237 } else {
1238 MS_LOG(ERROR) << "Unsupported destination.";
1239 return RET_ERROR;
1240 }
1241 MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
1242 "Export model type parameter error");
1243 MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
1244 "Export quant type parameter error");
1245   MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Currently only the FT_FLATBUFFERS format is supported");
1246
1247 bool orig_train_state = IsTrain();
1248 TrainExport texport(destination);
1249 int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
1250 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
1251
1252 if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
1253 std::vector<kernel::KernelExec *> export_kernels = {};
1254 status = FindExportKernels(&export_kernels, out_put_tensor_name, const_fold_kernels_);
1255 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
1256 status =
1257 texport.ExportNet(export_kernels, tensors_, const_output_tensors_, out_put_tensor_name, model_.get(), quant_type);
1258 } else {
1259 if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
1260 std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
1261 return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
1262 })) {
1263 status = texport.SaveModel(model_.get(), destination);
1264 if (orig_train_state) Train();
1265 return status;
1266 } else {
1267 if (quant_type == QT_NONE) {
1268 status = texport.ExportNet(
1269 (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_,
1270 (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type);
1271 } else {
1272 status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, {},
1273 (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
1274 model_.get(), quant_type);
1275 }
1276 }
1277 }
1278 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
1279
1280 if (model_type == MT_INFERENCE) {
1281 status = texport.TrainModelDrop();
1282 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
1283 status = texport.TrainModelFusion();
1284 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
1285 }
1286 if constexpr (std::is_same_v<DestType, const std::string &>) {
1287 status = texport.SaveToFile();
1288 if (status != RET_OK) {
1289 MS_LOG(ERROR) << "failed to save to " << destination;
1290 return status;
1291 }
1292 } else {
1293 status = texport.SaveToBuffer();
1294 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
1295 }
1296 if (orig_train_state) Train();
1297 return status;
1298 }
1299
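// Illustrative (hypothetical) call from application code, shown here only as a usage sketch:
//   session->Export("model_infer.ms", mindspore::lite::MT_INFERENCE, mindspore::lite::QT_NONE,
//                   mindspore::lite::FT_FLATBUFFERS, {});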
1300 int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
1301 FormatType format, std::vector<std::string> out_put_tensor_name) {
1302 return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
1303 }
1304
1305 int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
1306 std::vector<std::string> out_put_tensor_name) {
1307 return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
1308 }
1309
1310 int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type,
1311 FormatType format, bool enable_fp16,
1312 const std::vector<std::string> &changeable_weights_name) {
1313 MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
1314 struct stat path_type;
1315 if (stat(file_name.c_str(), &path_type) == RET_OK) {
1316 if (path_type.st_mode & S_IFDIR) {
1317       MS_LOG(ERROR) << "Destination must be a file path, but a directory was given";
1318 return RET_ERROR;
1319 }
1320 }
1321   MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Currently only the FT_FLATBUFFERS format is supported");
1322 MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR,
1323 "Currently, can only export inference-model's weights.");
1324
1325 TrainExport texport(file_name);
1326 auto status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
1327 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
1328 // Find and prepare a list of kernels which are const folded
1329 MS_CHECK_TRUE_MSG(FindConstFoldedKernels() == RET_OK, RET_ERROR, "FindConstFoldedKernels failed.");
1330 status = texport.ExportNet(const_fold_kernels_, tensors_, const_output_tensors_, eval_output_tensor_names_,
1331 model_.get(), QT_DEFAULT);
1332 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
1333 status = texport.TrainModelDrop();
1334 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
1335 status = texport.TrainModelFusion();
1336 TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
1337 status = texport.SaveWeightsToFile(enable_fp16, changeable_weights_name);
1338 if (status != RET_OK) {
1339 MS_LOG(ERROR) << "failed to save to " << file_name;
1340 return status;
1341 }
1342 return RET_OK;
1343 }
1344
1345 std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
1346 std::vector<lite::Tensor *> features;
1347 for (auto cur_tensor : this->tensors_) {
1348 if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
1349 features.push_back(cur_tensor);
1350 }
1351 }
1352 return features;
1353 }
1354
1355 std::vector<lite::Tensor *> TrainSession::GetTrainableParams() const { return trainable_parameters_; }
1356
1357 int TrainSession::UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) {
1358 for (auto feature : features_map) {
1359 bool find = false;
1360 for (auto tensor : tensors_) {
1361 if (!tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
1362 continue;
1363 }
1364 if (feature->tensor_name() != tensor->tensor_name()) {
1365 continue;
1366 }
1367 if (feature->Size() != tensor->Size()) {
1368 MS_LOG(ERROR) << "feature name:" << feature->tensor_name() << ",len diff:"
1369 << "old is:" << tensor->Size() << "new is:" << feature->Size();
1370 return RET_ERROR;
1371 }
1372 find = true;
1373 memcpy(tensor->data(), feature->data(), tensor->Size());
1374 }
1375 if (!find) {
1376 MS_LOG(ERROR) << "cannot find feature:" << feature->tensor_name() << ",update failed";
1377 return RET_ERROR;
1378 }
1379 }
1380 return RET_OK;
1381 }
1382
1383 std::set<schema::PrimitiveType> inPlaceSupportedKernels = {
1384 mindspore::schema::PrimitiveType_Activation, mindspore::schema::PrimitiveType_ActivationGrad,
1385 mindspore::schema::PrimitiveType_Reshape, mindspore::schema::PrimitiveType_DivFusion,
1386 mindspore::schema::PrimitiveType_AddFusion, mindspore::schema::PrimitiveType_SubFusion,
1387 mindspore::schema::PrimitiveType_RealDiv, mindspore::schema::PrimitiveType_Select,
1388 mindspore::schema::PrimitiveType_BiasAdd, mindspore::schema::PrimitiveType_BiasAddGrad,
1389 mindspore::schema::PrimitiveType_Sqrt, mindspore::schema::PrimitiveType_Abs,
1390 mindspore::schema::PrimitiveType_Cos, mindspore::schema::PrimitiveType_Log,
1391 mindspore::schema::PrimitiveType_Square, mindspore::schema::PrimitiveType_Rsqrt,
1392 mindspore::schema::PrimitiveType_Sin, mindspore::schema::PrimitiveType_LogicalNot,
1393 mindspore::schema::PrimitiveType_LogicalAnd, mindspore::schema::PrimitiveType_LogicalOr,
1394 mindspore::schema::PrimitiveType_Floor, mindspore::schema::PrimitiveType_Ceil,
1395 mindspore::schema::PrimitiveType_Round, mindspore::schema::PrimitiveType_Neg,
1396 mindspore::schema::PrimitiveType_Reciprocal, mindspore::schema::PrimitiveType_Erf,
1397 mindspore::schema::PrimitiveType_Maximum, mindspore::schema::PrimitiveType_Minimum,
1398 mindspore::schema::PrimitiveType_FloorDiv, mindspore::schema::PrimitiveType_FloorMod,
1399 mindspore::schema::PrimitiveType_Eltwise, mindspore::schema::PrimitiveType_SquaredDifference,
1400 mindspore::schema::PrimitiveType_ExpandDims, mindspore::schema::PrimitiveType_Cast,
1401 mindspore::schema::PrimitiveType_Flatten, mindspore::schema::PrimitiveType_FlattenGrad,
1402 mindspore::schema::PrimitiveType_Squeeze, mindspore::schema::PrimitiveType_Unsqueeze};
1403
1404 bool TrainSession::IsInPlaceKernel(kernel::KernelExec *kernel) {
1405 if (inPlaceSupportedKernels.find(kernel->type()) != inPlaceSupportedKernels.end() &&
1406 !(kernel->type() == mindspore::schema::PrimitiveType_Activation &&
1407 kernel->op_parameter()->type_ == schema::ActivationType_SIGMOID)) {
1408 return true;
1409 }
1410 return false;
1411 }
1412
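// An output tensor may share its input's buffer when the kernel supports in-place execution, the input is a VAR
// tensor whose remaining reference count is 1 (about to be released), and the two tensors have the same size.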
1413 bool TrainSession::IsInPlaceTensor(kernel::KernelExec *kernel, uint32_t idx,
1414 const std::unordered_map<lite::Tensor *, int> &ref_count, uint32_t *input_idx) {
1415 if (IsInPlaceKernel(kernel)) {
1416 auto out_tensor = kernel->out_tensors().at(idx);
1417 for (size_t i = 0; i < kernel->in_tensors().size(); i++) {
1418 auto tensor = kernel->in_tensors().at(i);
1419 if ((tensor->category() == lite::Category::VAR) && (ref_count.find(tensor) != ref_count.end()) &&
1420 (tensor->init_ref_count() == 1 || (tensor->init_ref_count() > 1 && ref_count.at(tensor) == 1)) &&
1421 (out_tensor->Size() == tensor->Size())) {
1422 *input_idx = static_cast<uint32_t>(i);
1423 return true;
1424 }
1425 }
1426 }
1427 return false;
1428 }
1429
1430 size_t TrainSession::GetInplaceTensorOffset(kernel::KernelExec *kernel,
1431 const std::unordered_map<lite::Tensor *, size_t> &offset_map,
1432 std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx) {
1433 auto tensor = kernel->in_tensors().at(input_idx);
1434 ref_count->at(tensor) = ref_count->at(tensor) + 1;
1435 return offset_map.at(tensor);
1436 }
1437
1438 lite::Tensor *TrainSession::FindObfTensor() {
1439 for (auto node : model_->graph_.all_nodes_) {
1440 if (node->name_.find(kObfNodeName) != std::string::npos) {
1441 auto idx = node->input_indices_[kDataIndex];
1442 return tensors_[idx];
1443 }
1444 }
1445 return nullptr;
1446 }
1447
1448 int TrainSession::ChangeObfWeight(std::string tensor_name, float obf_ratio) {
1449 float data[1] = {obf_ratio};
1450 auto new_tensor = lite::Tensor::CreateTensor(tensor_name, TypeId::kNumberTypeFloat32, {1, 1}, data, kFloatSize);
1451 std::vector<lite::Tensor *> modify_tensors;
1452 if (new_tensor == nullptr) {
1453 MS_LOG(ERROR) << "Create tensor failed";
1454 return RET_ERROR;
1455 }
1456 modify_tensors.emplace_back(new_tensor);
1457 auto ret = this->UpdateWeights(modify_tensors);
1458 if (ret != kSuccess) {
1459 MS_LOG(ERROR) << "UpdateWeights failed.";
1460 return RET_ERROR;
1461 }
1462 return RET_OK;
1463 }
1464
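// Reads the obfuscation ratio stored in the weight of the obf_op-obf_mul node; when change_weight is true the
// weight is reset to 1.0 so the network runs as if unobfuscated. Returns 1.0 for models without an obfuscation node.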
1465 float TrainSession::ModelRecoverObfuscate(bool change_weight) {
1466 float true_obf_ratio = 1.0;
1467 auto tensor = FindObfTensor();
1468 if (tensor != nullptr) {
1469 std::string tensor_name = tensor->tensor_name();
1470 true_obf_ratio = *(reinterpret_cast<float *>(tensor->data()));
1471 if (!change_weight) {
1472 return true_obf_ratio;
1473 }
1474 float init_obf_ratio = 1.0;
1475 ChangeObfWeight(tensor_name, init_obf_ratio);
1476 }
1477 return true_obf_ratio;
1478 }
1479
1480 int TrainSession::ModelDeObfuscate(float obf_ratio) {
1481 if (!FloatCompare(obf_ratio, 0.0)) {
1482 auto *tensor = FindObfTensor();
1483 if (tensor != nullptr) {
1484 std::string tensor_name = tensor->tensor_name();
1485 return ChangeObfWeight(tensor_name, obf_ratio);
1486 }
1487 MS_LOG(ERROR) << "Obfuscate tensor is null";
1488 return RET_ERROR;
1489 }
1490 MS_LOG(ERROR) << "Obfuscate value is 0";
1491 return RET_ERROR;
1492 }
1493 } // namespace lite
1494 } // namespace mindspore
1495