/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/benchmark_train/net_train.h"
#define __STDC_FORMAT_MACROS
#include <cinttypes>
#undef __STDC_FORMAT_MACROS
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <new>
#include <utility>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/common/common.h"
#include "include/ms_tensor.h"
#include "include/context.h"
#include "include/version.h"
#include "include/model.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"

namespace mindspore {
namespace lite {
static const char *DELIM_SLASH = "/";
constexpr const char *DELIM_COLON = ":";
constexpr const char *DELIM_COMMA = ",";
constexpr int RET_TOO_BIG = -9;
constexpr int kField0 = 0;
constexpr int kField1 = 1;
constexpr int kField2 = 2;
constexpr int kField3 = 3;
constexpr int kField4 = 4;
constexpr int kFieldsToPrint = 5;
constexpr int kPrintOffset = 4;
constexpr int kCPUBindFlag2 = 2;
constexpr int kCPUBindFlag1 = 1;
static const int kTHOUSAND = 1000;

namespace {
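// Reads the entire file at |file| into a newly allocated float buffer and
// stores the byte count in |size|. Returns nullptr on any failure; the
// caller takes ownership of the returned buffer and must delete[] it.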
float *ReadFileBuf(const char *file, size_t *size) {
  if (file == nullptr) {
    MS_LOG(ERROR) << "file is nullptr";
    return nullptr;
  }
  MS_ASSERT(size != nullptr);
  std::string real_path = RealPath(file);
  std::ifstream ifs(real_path);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
    return nullptr;
  }

  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return nullptr;
  }

  ifs.seekg(0, std::ios::end);
  *size = ifs.tellg();
  // std::make_unique would throw std::bad_alloc on failure, making the null
  // check below dead code; use a nothrow allocation so the check is real.
  std::unique_ptr<float[]> buf(new (std::nothrow) float[*size / sizeof(float) + 1]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
    ifs.close();
    return nullptr;
  }

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
  ifs.close();

  return buf.release();
}
}  // namespace
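
// Note: despite its name, this fills |data| with the deterministic byte
// pattern 0, 1, 2, ..., presumably so benchmark runs stay reproducible.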
int NetTrain::GenerateRandomData(size_t size, void *data) {
  MS_ASSERT(data != nullptr);
  char *casted_data = static_cast<char *>(data);
  for (size_t i = 0; i < size; i++) {
    casted_data[i] = static_cast<char>(i);
  }
  return RET_OK;
}
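
// Fills every input tensor with generated data via GenerateRandomData.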
int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  for (auto tensor : *ms_inputs) {
    MS_ASSERT(tensor != nullptr);
    auto input_data = tensor->MutableData();
    if (input_data == nullptr) {
      MS_LOG(ERROR) << "MallocData for inTensor failed";
      return RET_ERROR;
    }
    auto tensor_byte_size = tensor->Size();
    auto status = GenerateRandomData(tensor_byte_size, input_data);
    if (status != RET_OK) {
      std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
      MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
      return status;
    }
  }
  return RET_OK;
}
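
// Populates the model inputs, either with generated data (when no input
// file was given on the command line) or from the user-supplied binary files.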
int NetTrain::LoadInput(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  if (flags_->in_data_file_.empty()) {
    auto status = GenerateInputData(ms_inputs);
    if (status != RET_OK) {
      std::cerr << "Generate input data error " << status << std::endl;
      MS_LOG(ERROR) << "Generate input data error " << status;
      return status;
    }
  } else {
    auto status = ReadInputFile(ms_inputs);
    if (status != RET_OK) {
      std::cerr << "Read Input File error, " << status << std::endl;
      MS_LOG(ERROR) << "Read Input File error, " << status;
      return status;
    }
  }
  return RET_OK;
}
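
// Reads one binary file per input tensor; the files are expected to be named
// <inDataFile>1.bin, <inDataFile>2.bin, ... and must match the tensor byte
// sizes exactly.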
int NetTrain::ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  if (ms_inputs->empty()) {
    return RET_OK;
  }

  if (this->flags_->in_data_type_ == kImage) {
    MS_LOG(ERROR) << "Image input is not supported";
    return RET_ERROR;
  } else {
    for (size_t i = 0; i < ms_inputs->size(); i++) {
      auto cur_tensor = ms_inputs->at(i);
      MS_ASSERT(cur_tensor != nullptr);
      size_t size;
      std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
      auto bin_buf = ReadFile(file_name.c_str(), &size);
      if (bin_buf == nullptr) {
        MS_LOG(ERROR) << "ReadFile return nullptr";
        return RET_ERROR;
      }
      auto tensor_data_size = cur_tensor->Size();
      if (size != tensor_data_size) {
        std::cerr << "Input binary file size error, required: " << tensor_data_size << ", but got: " << size
                  << std::endl;
        MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", but got: " << size;
        delete[] bin_buf;  // ReadFile allocates with new[], so delete[] must be used
        return RET_ERROR;
      }
      auto input_data = cur_tensor->MutableData();
      memcpy(input_data, bin_buf, tensor_data_size);
      delete[] bin_buf;
    }
  }
  return RET_OK;
}
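
// Compares every forward output against the expected data in
// <expectedDataFile><i>.bin and prints the mean bias (in percent) over all
// output tensors. Returns RET_TOO_BIG when the mean bias exceeds the
// accuracy threshold.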
int NetTrain::CompareOutput(const session::LiteSession &lite_session) {
  std::cout << "================ Comparing Forward Output data ================" << std::endl;
  float total_bias = 0;
  int total_size = 0;
  bool has_error = false;
  auto tensors_list = lite_session.GetOutputs();
  if (tensors_list.empty()) {
    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
    return RET_ERROR;
  }
  mindspore::tensor::MSTensor *tensor = nullptr;
  int i = 1;
  for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
    tensor = lite_session.GetOutputByTensorName(it->first);
    std::cout << "output is tensor " << it->first << "\n";
    auto outputs = tensor->data();
    size_t size;
    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
    auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size));
    if (bin_buf == nullptr) {
      MS_LOG(ERROR) << "ReadFile return nullptr";
      std::cout << "ReadFile return nullptr" << std::endl;
      return RET_ERROR;
    }
    if (size != tensor->Size()) {
      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
                    << ", read size: " << size;
      std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
                << ", read size: " << size << std::endl;
      return RET_ERROR;
    }
    float bias = CompareData<float>(bin_buf.get(), tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
    if (bias >= 0) {
      total_bias += bias;
      total_size++;
    } else {
      has_error = true;
      break;
    }
    i++;
  }

  if (!has_error) {
    float mean_bias;
    if (total_size != 0) {
      mean_bias = total_bias / total_size * 100;
    } else {
      mean_bias = 0;
    }

    std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
              << ", threshold is: " << this->flags_->accuracy_threshold_ << std::endl;
    std::cout << "=======================================================" << std::endl << std::endl;

    if (mean_bias > this->flags_->accuracy_threshold_) {
      MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
      std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
      return RET_TOO_BIG;
    } else {
      return RET_OK;
    }
  } else {
    MS_LOG(ERROR) << "Error in CompareData";
    std::cerr << "Error in CompareData" << std::endl;
    std::cout << "=======================================================" << std::endl << std::endl;
    return RET_ERROR;
  }
}
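
// Runs |epochs_| training iterations and reports the min/max/average step
// time; when time profiling is enabled, per-op timing tables are printed too.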
int NetTrain::MarkPerformance(const std::unique_ptr<session::LiteSession> &session) {
  MS_LOG(INFO) << "Running train loops...";
  std::cout << "Running train loops..." << std::endl;
  uint64_t time_min = UINT64_MAX;
  uint64_t time_max = 0;
  uint64_t time_avg = 0;

  for (int i = 0; i < flags_->epochs_; i++) {
    session->BindThread(true);
    auto start = GetTimeUs();
    auto status =
      flags_->time_profiling_ ? session->RunGraph(before_call_back_, after_call_back_) : session->RunGraph();
    if (status != 0) {
      MS_LOG(ERROR) << "Inference error " << status;
      std::cerr << "Inference error " << status << std::endl;
      return status;
    }

    auto end = GetTimeUs();
    auto time = end - start;
    time_min = std::min(time_min, time);
    time_max = std::max(time_max, time);
    time_avg += time;
    session->BindThread(false);
  }

  if (flags_->time_profiling_) {
    const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
    const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
    PrintResult(per_op_name, op_times_by_name_);
    PrintResult(per_op_type, op_times_by_type_);
  }

  if (flags_->epochs_ > 0) {
    time_avg /= flags_->epochs_;
    MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
                 << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
                 << ", MaxRunTime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
    printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRunTime = %f ms, AvgRunTime = %f ms\n",
           flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
           time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
  }
  return RET_OK;
}
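
// Runs a single forward pass and compares the outputs with the expected
// data. When |enforce_accuracy| is false, a bias above the threshold is
// reported but not treated as an error.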
int NetTrain::MarkAccuracy(const std::unique_ptr<session::LiteSession> &session, bool enforce_accuracy) {
  MS_LOG(INFO) << "MarkAccuracy";
  for (auto &msInput : session->GetInputs()) {
    switch (msInput->data_type()) {
      case TypeId::kNumberTypeFloat:
      case TypeId::kNumberTypeFloat32:
        PrintInputData<float>(msInput);
        break;
      case TypeId::kNumberTypeInt32:
        PrintInputData<int>(msInput);
        break;
      default:
        MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported.";
        return RET_ERROR;
    }
  }
  auto status = session->RunGraph();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Inference error " << status;
    std::cerr << "Inference error " << status << std::endl;
    return status;
  }

  status = CompareOutput(*session);
  if (status == RET_TOO_BIG && !enforce_accuracy) {
    MS_LOG(INFO) << "Accuracy error is big but not enforced";
    std::cout << "Accuracy error is big but not enforced" << std::endl;
    return RET_OK;
  }

  if (status != RET_OK) {
    MS_LOG(ERROR) << "Compare output error " << status;
    std::cerr << "Compare output error " << status << std::endl;
    return status;
  }
  return RET_OK;
}
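
// Maps the numeric CPU-bind flag to the CpuBindMode enum:
// 2 -> MID_CPU, 1 -> HIGHER_CPU, anything else -> NO_BIND.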
static CpuBindMode FlagToBindMode(int flag) {
  if (flag == kCPUBindFlag2) {
    return MID_CPU;
  }
  if (flag == kCPUBindFlag1) {
    return HIGHER_CPU;
  }
  return NO_BIND;
}
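
// Creates a training session: a transfer session when a backbone model file
// is given, otherwise a plain train session. When |epochs| > 0 the session
// is put into training mode (with virtual batching if requested).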
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(const std::string &filename,
                                                                            const std::string &bb_filename,
                                                                            const Context &context,
                                                                            const TrainCfg &train_cfg, int epochs) {
  std::unique_ptr<session::LiteSession> session = nullptr;
  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
  if (!bb_filename.empty()) {
    MS_LOG(INFO) << "CreateTransferSession from model files " << filename << " and " << bb_filename;
    std::cout << "CreateTransferSession from model files " << filename << " and " << bb_filename << std::endl;
    session = std::unique_ptr<session::LiteSession>(
      session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
    if (session == nullptr) {
      MS_LOG(ERROR) << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str();
      std::cout << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str() << std::endl;
      return nullptr;
    }
  } else {
    MS_LOG(INFO) << "CreateTrainSession from model file " << filename.c_str();
    std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
    std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ << std::endl;
    session = std::unique_ptr<session::LiteSession>(
      session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
    if (session == nullptr) {
      MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
      std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
      return nullptr;
    }
  }
  if (epochs > 0) {
    if (flags_->virtual_batch_) {
      session->SetupVirtualBatch(epochs);
    }
    session->Train();
  }
  return session;
}
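
// Creates an inference session from an exported .ms model file (the suffix
// is appended when missing) and compiles its graph.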
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForInference(const std::string &filename,
                                                                                const Context &context) {
  std::unique_ptr<session::LiteSession> session = nullptr;
  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
  std::string filenamems = filename;
  if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") {
    filenamems = filenamems + ".ms";
  }

  MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
  std::cout << "start reading model file " << filenamems.c_str() << std::endl;
  auto *model = mindspore::lite::Model::Import(filenamems.c_str());
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for train session failed";
    return nullptr;
  }
  session = std::unique_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
  if (session == nullptr) {
    MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
    std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
    delete model;
    return nullptr;
  }
  if (session->CompileGraph(model) != RET_OK) {
    MS_LOG(ERROR) << "Cannot compile model";
    delete model;
    return nullptr;
  }
  delete model;
  return session;
}
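
// Top-level driver: builds the session (train or inference), applies any
// requested input resizing, loads the input data, and then runs the
// performance and/or accuracy measurements selected by the flags.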
int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session,
                                  int epochs, bool check_accuracy) {
  auto start_prepare_time = GetTimeUs();
  Context context;
  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_);
  context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
  context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
  context.thread_num_ = flags_->num_threads_;

  TrainCfg train_cfg;
  if (!flags_->loss_name_.empty()) {
    train_cfg.loss_name_ = flags_->loss_name_;
  }
  train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_;
  std::unique_ptr<session::LiteSession> session;
  if (train_session) {
    session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs);
    if (session == nullptr) {
      MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed.";
      return RET_ERROR;
    }
  } else {
    session = CreateAndRunNetworkForInference(filename, context);
    if (session == nullptr) {
      MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed.";
      return RET_ERROR;
    }
  }

  if (!flags_->resize_dims_.empty()) {
    auto ret = session->Resize(session->GetInputs(), flags_->resize_dims_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Input tensor resize failed.";
      std::cout << "Input tensor resize failed." << std::endl;
      return ret;
    }
  }

  auto end_prepare_time = GetTimeUs();
  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
  // Load input
  MS_LOG(INFO) << "Load input data";
  auto ms_inputs = session->GetInputs();
  auto status = LoadInput(&ms_inputs);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Load input data error";
    return status;
  }

  if ((epochs > 0) && train_session) {
    status = MarkPerformance(session);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
      std::cout << "Run MarkPerformance error: " << status << std::endl;
      return status;
    }
    SaveModels(session);  // save file if flags are on
  }
  if (!flags_->data_file_.empty()) {
    session->Eval();

    status = MarkAccuracy(session, check_accuracy);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
      std::cout << "Run MarkAccuracy error: " << status << std::endl;
      return status;
    }
  }
  return RET_OK;
}
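
// Trains the network, then re-runs the exported models to verify that they
// still load, execute, and reproduce the expected outputs.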
int NetTrain::RunNetTrain() {
  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, true, flags_->epochs_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
              << std::endl;
    return status;
  }

  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
    std::cout << "Run CheckExecute error: " << status << std::endl;
    return status;
  }
  return RET_OK;
}
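
// Exports the trained model and/or its inference counterpart, each in both
// weight-quantized (_qt) and non-quantized form, as requested by the flags.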
int NetTrain::SaveModels(const std::unique_ptr<session::LiteSession> &session) {
  if (!flags_->export_file_.empty()) {
    if (flags_->bb_model_file_.empty()) {
      auto status = session->Export(flags_->export_file_ + "_qt", lite::MT_TRAIN, lite::QT_WEIGHT);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt";
        std::cout << "Export quantized model error " << flags_->export_file_ + "_qt" << std::endl;
        return RET_ERROR;
      }
    }
    auto status = session->Export(flags_->export_file_, lite::MT_TRAIN, lite::QT_NONE);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_;
      std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl;
      return RET_ERROR;
    }
  }
  if (!flags_->inference_file_.empty()) {
    auto status = session->Export(flags_->inference_file_ + "_qt", lite::MT_INFERENCE, lite::QT_WEIGHT);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt";
      std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl;
      return RET_ERROR;
    }

    auto tick = GetTimeUs();
    status = session->Export(flags_->inference_file_, lite::MT_INFERENCE, lite::QT_NONE);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_;
      std::cout << "Export non quantized inference model error " << flags_->inference_file_ << std::endl;
      return status;
    }
    std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << " us\n";
  }
  return RET_OK;
}
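
// Re-loads every model saved by SaveModels and runs it for zero epochs to
// confirm the exported files are usable; the quantized variants are run
// with accuracy enforcement disabled.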
int NetTrain::CheckExecutionOfSavedModels() {
  int status = RET_OK;
  if (!flags_->export_file_.empty()) {
    status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
      return status;
    }
    if (flags_->bb_model_file_.empty()) {
      status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
        return status;
      }
    }
  }
  if (!flags_->inference_file_.empty()) {
    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
      return status;
    }
    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
      return status;
    }
  }
  return status;
}
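
// Prints the shape and element sum of |tensor| for layer-by-layer checksum
// debugging; float32 data is additionally scanned for NaNs.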
void NetTrain::CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out) {
  int tensor_size = tensor->ElementsNum();
  void *data = tensor->MutableData();
  TypeId type = tensor->data_type();
  std::cout << node_type << " " << in_out << id << " shape=" << tensor->shape() << " sum=";
  switch (type) {
    case kNumberTypeFloat32:
      TensorNan(reinterpret_cast<float *>(data), tensor_size);
      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
      break;
    case kNumberTypeInt32:
      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
      break;
#endif
    default:
      std::cout << "unsupported type: " << type << std::endl;
      break;
  }
}
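
// Installs the before/after RunGraph callbacks used for per-op time
// profiling. The before callback also zeroes the outputs of optimizer ops
// (Adam, Assign, SGD), presumably so stale buffer contents cannot leak into
// the layer checksums.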
int NetTrain::InitCallbackParameter() {
  // before callback
  before_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                          const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                          const mindspore::CallBackParam &callParam) {
    if (before_inputs.empty()) {
      MS_LOG(INFO) << "The beforeInputs list is empty";
    }
    if (before_outputs.empty()) {
      MS_LOG(INFO) << "The beforeOutputs list is empty";
    }
    if (op_times_by_type_.find(callParam.node_type) == op_times_by_type_.end()) {
      op_times_by_type_.insert(std::make_pair(callParam.node_type, std::make_pair(0, 0.0f)));
    }
    if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) {
      op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f)));
    }
    op_call_times_total_++;
    op_begin_ = GetTimeUs();
    if ((callParam.node_type == "Adam") || (callParam.node_type == "Assign") || (callParam.node_type == "SGD")) {
      for (auto tensor : before_outputs) {
        std::fill(reinterpret_cast<int8_t *>(tensor->MutableData()),
                  reinterpret_cast<int8_t *>(tensor->MutableData()) + tensor->Size(), 0);
      }
    }
    return true;
  };

  // after callback
  after_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                         const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                         const mindspore::CallBackParam &call_param) {
    uint64_t opEnd = GetTimeUs();
    if (after_inputs.empty()) {
      MS_LOG(INFO) << "The after inputs list is empty";
    }
    if (after_outputs.empty()) {
      MS_LOG(INFO) << "The after outputs list is empty";
    }
    float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
    op_cost_total_ += cost;
    op_times_by_type_[call_param.node_type].first++;
    op_times_by_type_[call_param.node_type].second += cost;
    op_times_by_name_[call_param.node_name].first++;
    op_times_by_name_[call_param.node_name].second += cost;
    if (flags_->layer_checksum_) {
      for (size_t i = 0; i < after_inputs.size(); i++) {
        CheckSum(after_inputs.at(i), call_param.node_type, i, "in");
      }
      for (size_t i = 0; i < after_outputs.size(); i++) {
        CheckSum(after_outputs.at(i), call_param.node_type, i, "out");
      }
      std::cout << std::endl;
    }
    return true;
  };
  return RET_OK;
}
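
// Parses the resize-dims command-line string: shapes are separated by ':'
// and the dimensions within a shape by ','.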
void NetTrainFlags::InitResizeDimsList() {
  std::string content = this->resize_dims_in_;
  std::vector<int> shape;
  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
  for (const auto &shape_str : shape_strs) {
    shape.clear();
    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
    std::cout << "Resize Dims: ";
    for (const auto &dim_str : dim_strs) {
      std::cout << dim_str << " ";
      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
    }
    std::cout << std::endl;
    this->resize_dims_.emplace_back(shape);
  }
}
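
// Validates the parsed flags, logs the effective configuration, and installs
// the profiling callbacks when time profiling is requested.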
int NetTrain::Init() {
  if (this->flags_ == nullptr) {
    return 1;
  }
  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;

  if (this->flags_->epochs_ < 0) {
    MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be greater than or equal to 0";
    std::cerr << "epochs:" << this->flags_->epochs_ << " must be greater than or equal to 0" << std::endl;
    return RET_ERROR;
  }

  if (this->flags_->num_threads_ < 1) {
    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
    return RET_ERROR;
  }

  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;

  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
    MS_LOG(ERROR) << "expectedDataFile is not supported when inDataFile is not provided";
    std::cerr << "expectedDataFile is not supported when inDataFile is not provided" << std::endl;
    return RET_ERROR;
  }

  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
    MS_LOG(ERROR) << "exportDataFile is not supported when inDataFile is not provided";
    std::cerr << "exportDataFile is not supported when inDataFile is not provided" << std::endl;
    return RET_ERROR;
  }

  if (flags_->model_file_.empty()) {
    MS_LOG(ERROR) << "modelPath is required";
    std::cerr << "modelPath is required" << std::endl;
    return 1;
  }

  if (flags_->time_profiling_) {
    auto status = InitCallbackParameter();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Init callback parameter failed.";
      std::cerr << "Init callback parameter failed." << std::endl;
      return RET_ERROR;
    }
  }
  flags_->InitResizeDimsList();
  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
    return RET_ERROR;
  }
  return RET_OK;
}

namespace {
constexpr int kNumToPrint = 5;
}  // namespace
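
// Renders a per-op profiling map as an aligned five-column table: op
// name/type, average time per epoch, share of the total cost, call count,
// and total time.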
int NetTrain::PrintResult(const std::vector<std::string> &title,
                          const std::map<std::string, std::pair<int, float>> &result) {
  std::vector<size_t> columnLenMax(kFieldsToPrint);
  std::vector<std::vector<std::string>> rows;

  for (auto &iter : result) {
    std::string stringBuf[kFieldsToPrint];
    std::vector<std::string> columns;
    size_t len;

    len = iter.first.size();
    if (len > columnLenMax.at(kField0)) {
      columnLenMax.at(kField0) = len + kPrintOffset;
    }
    columns.push_back(iter.first);

    stringBuf[kField1] = std::to_string(iter.second.second / flags_->epochs_);
    len = stringBuf[kField1].length();
    if (len > columnLenMax.at(kField1)) {
      columnLenMax.at(kField1) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField1]);

    stringBuf[kField2] = std::to_string(iter.second.second / op_cost_total_);
    len = stringBuf[kField2].length();
    if (len > columnLenMax.at(kField2)) {
      columnLenMax.at(kField2) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField2]);

    stringBuf[kField3] = std::to_string(iter.second.first);
    len = stringBuf[kField3].length();
    if (len > columnLenMax.at(kField3)) {
      columnLenMax.at(kField3) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField3]);

    stringBuf[kField4] = std::to_string(iter.second.second);
    len = stringBuf[kField4].length();
    if (len > columnLenMax.at(kField4)) {
      columnLenMax.at(kField4) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField4]);

    rows.push_back(columns);
  }

  printf("-------------------------------------------------------------------------\n");
  for (int i = 0; i < kNumToPrint; i++) {
    auto printBuf = title[i];
    if (printBuf.size() > columnLenMax.at(i)) {
      columnLenMax.at(i) = printBuf.size();
    }
    printBuf.resize(columnLenMax.at(i), ' ');
    printf("%s\t", printBuf.c_str());
  }
  printf("\n");
  for (size_t i = 0; i < rows.size(); i++) {
    for (int j = 0; j < kNumToPrint; j++) {
      auto printBuf = rows[i][j];
      printBuf.resize(columnLenMax.at(j), ' ');
      printf("%s\t", printBuf.c_str());
    }
    printf("\n");
  }
  return RET_OK;
}
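
// Command-line entry point: parses the flags, then initializes and runs the
// trainer. A typical invocation (flag names inferred from the usage strings
// above; run with --help for the authoritative list) might look like:
//   ./benchmark_train --modelFile=net.ms --inDataFile=in_ --epochs=3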
int RunNetTrain(int argc, const char **argv) {
  NetTrainFlags flags;
  Option<std::string> err = flags.ParseFlags(argc, argv);

  if (err.IsSome()) {
    std::cerr << err.Get() << std::endl;
    std::cerr << flags.Usage() << std::endl;
    return RET_ERROR;
  }

  if (flags.help) {
    std::cerr << flags.Usage() << std::endl;
    return RET_OK;
  }

  NetTrain net_trainer(&flags);
  auto status = net_trainer.Init();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "NetTrain init Error : " << status;
    std::cerr << "NetTrain init Error : " << status << std::endl;
    return RET_ERROR;
  }

  status = net_trainer.RunNetTrain();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run NetTrain "
                  << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
                  << " Failed : " << status;
    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
              << " Failed : " << status << std::endl;
    return RET_ERROR;
  }

  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
               << " Success.";
  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
            << " Success." << std::endl;
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore