/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/benchmark_train/net_train.h"
#define __STDC_FORMAT_MACROS
#include <cinttypes>
#undef __STDC_FORMAT_MACROS
#include <algorithm>
#include <cstring>
#include <utility>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/common/common.h"
#include "include/ms_tensor.h"
#include "include/context.h"
#include "include/version.h"
#include "include/model.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"

namespace mindspore {
namespace lite {
static const char *DELIM_SLASH = "/";
constexpr const char *DELIM_COLON = ":";
constexpr const char *DELIM_COMMA = ",";
constexpr int RET_TOO_BIG = -9;
constexpr int kField0 = 0;
constexpr int kField1 = 1;
constexpr int kField2 = 2;
constexpr int kField3 = 3;
constexpr int kField4 = 4;
constexpr int kFieldsToPrint = 5;
constexpr int kPrintOffset = 4;
constexpr int kCPUBindFlag2 = 2;
constexpr int kCPUBindFlag1 = 1;
static const int kTHOUSAND = 1000;

namespace {
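// Reads a whole binary file into a float buffer allocated with new[] and reports the
// byte count through *size; returns nullptr on any failure. The caller owns the buffer.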
float *ReadFileBuf(const char *file, size_t *size) {
  if (file == nullptr) {
    MS_LOG(ERROR) << "file is nullptr";
    return nullptr;
  }
  MS_ASSERT(size != nullptr);
  std::string real_path = RealPath(file);
  std::ifstream ifs(real_path);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
    return nullptr;
  }

  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return nullptr;
  }

  ifs.seekg(0, std::ios::end);
  *size = ifs.tellg();
  std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
    ifs.close();
    return nullptr;
  }

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
  ifs.close();

  return buf.release();
}
}  // namespace

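// Despite its name, this fills the buffer with a deterministic byte ramp (values wrap
// modulo 256) so that benchmark runs are reproducible.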
int NetTrain::GenerateRandomData(size_t size, void *data) {
  MS_ASSERT(data != nullptr);
  char *casted_data = static_cast<char *>(data);
  for (size_t i = 0; i < size; i++) {
    casted_data[i] = static_cast<char>(i);
  }
  return RET_OK;
}

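// Fills every input tensor with generated data; used when no input file is supplied.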
int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  for (auto tensor : *ms_inputs) {
    MS_ASSERT(tensor != nullptr);
    auto input_data = tensor->MutableData();
    if (input_data == nullptr) {
      MS_LOG(ERROR) << "MallocData for inTensor failed";
      return RET_ERROR;
    }
    auto tensor_byte_size = tensor->Size();
    auto status = GenerateRandomData(tensor_byte_size, input_data);
    if (status != RET_OK) {
      std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
      MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
      return status;
    }
  }
  return RET_OK;
}

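// Populates the model inputs either from generated data or from the user-supplied
// input files, depending on whether the inDataFile flag is set.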
int NetTrain::LoadInput(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  if (flags_->in_data_file_.empty()) {
    auto status = GenerateInputData(ms_inputs);
    if (status != RET_OK) {
      std::cerr << "Generate input data error " << status << std::endl;
      MS_LOG(ERROR) << "Generate input data error " << status;
      return status;
    }
  } else {
    auto status = ReadInputFile(ms_inputs);
    if (status != RET_OK) {
      std::cerr << "Read Input File error, " << status << std::endl;
      MS_LOG(ERROR) << "Read Input File error, " << status;
      return status;
    }
  }
  return RET_OK;
}

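// Loads one raw .bin file per input tensor (named <inDataFile><i>.bin, 1-based) and
// copies it into the tensor after checking that the sizes match exactly.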
int NetTrain::ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  if (ms_inputs->empty()) {
    return RET_OK;
  }

  if (this->flags_->in_data_type_ == kImage) {
    MS_LOG(ERROR) << "Image input is not supported";
    return RET_ERROR;
  } else {
    for (size_t i = 0; i < ms_inputs->size(); i++) {
      auto cur_tensor = ms_inputs->at(i);
      MS_ASSERT(cur_tensor != nullptr);
      size_t size;
      std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
      auto bin_buf = ReadFile(file_name.c_str(), &size);
      if (bin_buf == nullptr) {
        MS_LOG(ERROR) << "ReadFile return nullptr";
        return RET_ERROR;
      }
      auto tensor_data_size = cur_tensor->Size();
      if (size != tensor_data_size) {
        std::cerr << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
                  << std::endl;
        MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size;
        delete[] bin_buf;  // buffer comes from ReadFile (new[]), so use delete[]
        return RET_ERROR;
      }
      auto input_data = cur_tensor->MutableData();
      memcpy(input_data, bin_buf, tensor_data_size);
      delete[] bin_buf;
    }
  }
  return RET_OK;
}

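// Compares every forward output tensor against its golden file (<dataFile><i>.bin)
// and accumulates the mean bias; returns RET_TOO_BIG when the bias exceeds the
// configured accuracy threshold.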
int NetTrain::CompareOutput(const session::LiteSession &lite_session) {
  std::cout << "================ Comparing Forward Output data ================" << std::endl;
  float total_bias = 0;
  int total_size = 0;
  bool has_error = false;
  auto tensors_list = lite_session.GetOutputs();
  if (tensors_list.empty()) {
    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
    return RET_ERROR;
  }
  mindspore::tensor::MSTensor *tensor = nullptr;
  int i = 1;
  for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
    tensor = lite_session.GetOutputByTensorName(it->first);
    std::cout << "output is tensor " << it->first << "\n";
    auto outputs = tensor->data();
    size_t size;
    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
    auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size));
    if (bin_buf == nullptr) {
      MS_LOG(ERROR) << "ReadFile return nullptr";
      std::cout << "ReadFile return nullptr" << std::endl;
      return RET_ERROR;
    }
    if (size != tensor->Size()) {
      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
                    << ", read size: " << size;
      std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
                << ", read size: " << size << std::endl;
      return RET_ERROR;
    }
    float bias = CompareData<float>(bin_buf.get(), tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
    if (bias >= 0) {
      total_bias += bias;
      total_size++;
    } else {
      has_error = true;
      break;
    }
    i++;
  }

  if (!has_error) {
    float mean_bias;
    if (total_size != 0) {
      mean_bias = total_bias / total_size * 100;
    } else {
      mean_bias = 0;
    }

    std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
              << " threshold is: " << this->flags_->accuracy_threshold_ << std::endl;
    std::cout << "=======================================================" << std::endl << std::endl;

    if (mean_bias > this->flags_->accuracy_threshold_) {
      MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
      std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
      return RET_TOO_BIG;
    } else {
      return RET_OK;
    }
  } else {
    MS_LOG(ERROR) << "Error in CompareData";
    std::cerr << "Error in CompareData" << std::endl;
    std::cout << "=======================================================" << std::endl << std::endl;
    return RET_ERROR;
  }
}

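// Runs the training graph for the requested number of epochs, tracking min/max/avg
// step times and, when time profiling is enabled, per-op statistics.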
int NetTrain::MarkPerformance(const std::unique_ptr<session::LiteSession> &session) {
  MS_LOG(INFO) << "Running train loops...";
  std::cout << "Running train loops..." << std::endl;
  uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
  uint64_t time_max = 0;
  uint64_t time_avg = 0;

  for (int i = 0; i < flags_->epochs_; i++) {
    session->BindThread(true);
    auto start = GetTimeUs();
    auto status =
      flags_->time_profiling_ ? session->RunGraph(before_call_back_, after_call_back_) : session->RunGraph();
    if (status != 0) {
      MS_LOG(ERROR) << "Inference error " << status;
      std::cerr << "Inference error " << status << std::endl;
      return status;
    }

    auto end = GetTimeUs();
    auto time = end - start;
    time_min = std::min(time_min, time);
    time_max = std::max(time_max, time);
    time_avg += time;
    session->BindThread(false);
  }

  if (flags_->time_profiling_) {
    const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
    const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
    PrintResult(per_op_name, op_times_by_name_);
    PrintResult(per_op_type, op_times_by_type_);
  }

  if (flags_->epochs_ > 0) {
    time_avg /= flags_->epochs_;
    MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
                 << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
                 << ", MaxRunTime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
    printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRunTime = %f ms, AvgRunTime = %f ms\n",
           flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
           time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
  }
  return RET_OK;
}

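// Prints the input tensors, runs one inference pass, and compares the outputs with
// the expected data; RET_TOO_BIG is tolerated unless enforce_accuracy is set.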
int NetTrain::MarkAccuracy(const std::unique_ptr<session::LiteSession> &session, bool enforce_accuracy) {
  MS_LOG(INFO) << "MarkAccuracy";
  for (auto &msInput : session->GetInputs()) {
    switch (msInput->data_type()) {
      case TypeId::kNumberTypeFloat:
      case TypeId::kNumberTypeFloat32:
        PrintInputData<float>(msInput);
        break;
      case TypeId::kNumberTypeInt32:
        PrintInputData<int>(msInput);
        break;
      default:
        MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported.";
        return RET_ERROR;
    }
  }
  auto status = session->RunGraph();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Inference error " << status;
    std::cerr << "Inference error " << status << std::endl;
    return status;
  }

  status = CompareOutput(*session);
  if (status == RET_TOO_BIG && !enforce_accuracy) {
    MS_LOG(INFO) << "Accuracy Error is big but not enforced";
    std::cout << "Accuracy Error is big but not enforced" << std::endl;
    return RET_OK;
  }

  if (status != RET_OK) {
    MS_LOG(ERROR) << "Compare output error " << status;
    std::cerr << "Compare output error " << status << std::endl;
    return status;
  }
  return RET_OK;
}

static CpuBindMode FlagToBindMode(int flag) {
  if (flag == kCPUBindFlag2) {
    return MID_CPU;
  }
  if (flag == kCPUBindFlag1) {
    return HIGHER_CPU;
  }
  return NO_BIND;
}

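// Builds a train (or transfer-learning) session from the given model file(s) and
// switches it to training mode when at least one epoch is requested.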
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(const std::string &filename,
                                                                            const std::string &bb_filename,
                                                                            const Context &context,
                                                                            const TrainCfg &train_cfg, int epochs) {
  std::unique_ptr<session::LiteSession> session = nullptr;
  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
  if (!bb_filename.empty()) {
    MS_LOG(INFO) << "CreateTransferSession from model files " << filename << " and " << bb_filename;
    std::cout << "CreateTransferSession from model files " << filename << " and " << bb_filename << std::endl;
    session = std::unique_ptr<session::LiteSession>(
      session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
    if (session == nullptr) {
      MS_LOG(ERROR) << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str();
      std::cout << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str() << std::endl;
      return nullptr;
    }
  } else {
    MS_LOG(INFO) << "CreateTrainSession from model file " << filename.c_str();
    std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
    std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ << std::endl;
    session = std::unique_ptr<session::LiteSession>(
      session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
    if (session == nullptr) {
      MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
      std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
      return nullptr;
    }
  }
  if (epochs > 0) {
    if (flags_->virtual_batch_) {
      session->SetupVirtualBatch(epochs);
    }
    session->Train();
  }
  return session;
}

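// Imports an .ms model (appending the extension if missing) and compiles it into a
// plain inference session.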
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForInference(const std::string &filename,
                                                                                const Context &context) {
  std::unique_ptr<session::LiteSession> session = nullptr;
  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
  std::string filenamems = filename;
  if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") {
    filenamems = filenamems + ".ms";
  }

  MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
  std::cout << "start reading model file " << filenamems.c_str() << std::endl;
  auto *model = mindspore::lite::Model::Import(filenamems.c_str());
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for train session failed";
    return nullptr;
  }
  session = std::unique_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
  if (session == nullptr) {
    MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
    std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
    delete model;
    return nullptr;
  }
  if (session->CompileGraph(model) != RET_OK) {
    MS_LOG(ERROR) << "Cannot compile model";
    delete model;
    return nullptr;
  }
  delete model;
  return session;
}

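// End-to-end driver: configures the context and train config, builds the requested
// session type, optionally resizes inputs, loads input data, and then benchmarks
// performance and/or accuracy according to the flags.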
int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session,
                                  int epochs, bool check_accuracy) {
  auto start_prepare_time = GetTimeUs();
  Context context;
  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_);
  context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
  context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
  context.thread_num_ = flags_->num_threads_;

  TrainCfg train_cfg;
  if (!flags_->loss_name_.empty()) {
    train_cfg.loss_name_ = flags_->loss_name_;
  }
  train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_;
  std::unique_ptr<session::LiteSession> session;
  if (train_session) {
    session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs);
    if (session == nullptr) {
      MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed.";
      return RET_ERROR;
    }
  } else {
    session = CreateAndRunNetworkForInference(filename, context);
    if (session == nullptr) {
      MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed.";
      return RET_ERROR;
    }
  }

  if (!flags_->resize_dims_.empty()) {
    auto ret = session->Resize(session->GetInputs(), flags_->resize_dims_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Input tensor resize failed.";
      std::cout << "Input tensor resize failed." << std::endl;
      return ret;
    }
  }

  auto end_prepare_time = GetTimeUs();
  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
  // Load input
  MS_LOG(INFO) << "Load input data";
  auto ms_inputs = session->GetInputs();
  auto status = LoadInput(&ms_inputs);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Load input data error";
    return status;
  }

  if ((epochs > 0) && train_session) {
    status = MarkPerformance(session);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
      std::cout << "Run MarkPerformance error: " << status << std::endl;
      return status;
    }
    SaveModels(session);  // save file if flags are on
  }
  if (!flags_->data_file_.empty()) {
    session->Eval();

    status = MarkAccuracy(session, check_accuracy);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
      std::cout << "Run MarkAccuracy error: " << status << std::endl;
      return status;
    }
  }
  return RET_OK;
}

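// Entry point for a benchmark run: trains the model and then re-executes the
// exported/saved models to verify them.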
int NetTrain::RunNetTrain() {
  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, true, flags_->epochs_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
              << std::endl;
    return status;
  }

  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
    std::cout << "Run CheckExecute error: " << status << std::endl;
    return status;
  }
  return RET_OK;
}

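// Exports the trained model (and an inference-only variant), each in both
// weight-quantized and non-quantized form, as requested by the export flags.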
int NetTrain::SaveModels(const std::unique_ptr<session::LiteSession> &session) {
  if (!flags_->export_file_.empty()) {
    if (flags_->bb_model_file_.empty()) {
      auto status = session->Export(flags_->export_file_ + "_qt", lite::MT_TRAIN, lite::QT_WEIGHT);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt";
        std::cout << "Export quantized model error " << flags_->export_file_ + "_qt" << std::endl;
        return RET_ERROR;
      }
    }
    auto status = session->Export(flags_->export_file_, lite::MT_TRAIN, lite::QT_NONE);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_;
      std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl;
      return RET_ERROR;
    }
  }
  if (!flags_->inference_file_.empty()) {
    auto status = session->Export(flags_->inference_file_ + "_qt", lite::MT_INFERENCE, lite::QT_WEIGHT);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt";
      std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl;
      return RET_ERROR;
    }

    auto tick = GetTimeUs();
    status = session->Export(flags_->inference_file_, lite::MT_INFERENCE, lite::QT_NONE);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_;
      std::cout << "Export non quantized inference model error " << flags_->inference_file_ << std::endl;
      return status;
    }
    std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
  }
  return RET_OK;
}

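// Re-runs every model saved by SaveModels to confirm that the exported files load
// and execute; quantized variants are run without enforcing accuracy.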
int NetTrain::CheckExecutionOfSavedModels() {
  int status = RET_OK;
  if (!flags_->export_file_.empty()) {
    status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
      return status;
    }
    if (flags_->bb_model_file_.empty()) {
      status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
        return status;
      }
    }
  }
  if (!flags_->inference_file_.empty()) {
    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
      return status;
    }
    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
      return status;
    }
  }
  return status;
}

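// Prints a per-tensor checksum (shape and element sum) used by the layer checksum
// debug mode to spot NaNs and diverging layers.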
void NetTrain::CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out) {
  int tensor_size = tensor->ElementsNum();
  void *data = tensor->MutableData();
  TypeId type = tensor->data_type();
  std::cout << node_type << " " << in_out << id << " shape=" << tensor->shape() << " sum=";
  switch (type) {
    case kNumberTypeFloat32:
      TensorNan(reinterpret_cast<float *>(data), tensor_size);
      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
      break;
    case kNumberTypeInt32:
      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
      break;
#endif
    default:
      std::cout << "unsupported type:" << type << std::endl;
      break;
  }
}

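// Installs the before/after callbacks used for time profiling: the before callback
// zeroes optimizer (Adam/Assign/SGD) outputs and starts the timer, the after
// callback accumulates per-op costs and optionally prints layer checksums.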
int NetTrain::InitCallbackParameter() {
  // before callback
  before_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                          const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                          const mindspore::CallBackParam &callParam) {
    if (before_inputs.empty()) {
      MS_LOG(INFO) << "before_inputs is empty";
    }
    if (before_outputs.empty()) {
      MS_LOG(INFO) << "before_outputs is empty";
    }
    if (op_times_by_type_.find(callParam.node_type) == op_times_by_type_.end()) {
      op_times_by_type_.insert(std::make_pair(callParam.node_type, std::make_pair(0, 0.0f)));
    }
    if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) {
      op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f)));
    }
    op_call_times_total_++;
    op_begin_ = GetTimeUs();
    if ((callParam.node_type == "Adam") || (callParam.node_type == "Assign") || (callParam.node_type == "SGD")) {
      for (auto tensor : before_outputs) {
        std::fill(reinterpret_cast<int8_t *>(tensor->MutableData()),
                  reinterpret_cast<int8_t *>(tensor->MutableData()) + tensor->Size(), 0);
      }
    }
    return true;
  };

  // after callback
  after_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                         const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                         const mindspore::CallBackParam &call_param) {
    uint64_t opEnd = GetTimeUs();
    if (after_inputs.empty()) {
      MS_LOG(INFO) << "after_inputs is empty";
    }
    if (after_outputs.empty()) {
      MS_LOG(INFO) << "after_outputs is empty";
    }
    float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
    op_cost_total_ += cost;
    op_times_by_type_[call_param.node_type].first++;
    op_times_by_type_[call_param.node_type].second += cost;
    op_times_by_name_[call_param.node_name].first++;
    op_times_by_name_[call_param.node_name].second += cost;
    if (flags_->layer_checksum_) {
      for (size_t i = 0; i < after_inputs.size(); i++) {
        CheckSum(after_inputs.at(i), call_param.node_type, i, "in");
      }
      for (size_t i = 0; i < after_outputs.size(); i++) {
        CheckSum(after_outputs.at(i), call_param.node_type, i, "out");
      }
      std::cout << std::endl;
    }
    return true;
  };
  return RET_OK;
}

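// Parses the resizeDims flag, e.g. "1,32,32,3:1,10", into one shape vector per input.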
void NetTrainFlags::InitResizeDimsList() {
  std::string content = this->resize_dims_in_;
  std::vector<int> shape;
  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
  for (const auto &shape_str : shape_strs) {
    shape.clear();
    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
    std::cout << "Resize Dims: ";
    for (const auto &dim_str : dim_strs) {
      std::cout << dim_str << " ";
      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
    }
    std::cout << std::endl;
    this->resize_dims_.emplace_back(shape);
  }
}

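// Validates the parsed command-line flags, logs the effective configuration, and
// prepares profiling callbacks and resize dimensions.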
int NetTrain::Init() {
  if (this->flags_ == nullptr) {
    return 1;
  }
  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;

  if (this->flags_->epochs_ < 0) {
    MS_LOG(ERROR) << "epochs: " << this->flags_->epochs_ << " must be greater than or equal to 0";
    std::cerr << "epochs: " << this->flags_->epochs_ << " must be greater than or equal to 0" << std::endl;
    return RET_ERROR;
  }

  if (this->flags_->num_threads_ < 1) {
    MS_LOG(ERROR) << "numThreads: " << this->flags_->num_threads_ << " must be greater than 0";
    std::cerr << "numThreads: " << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
    return RET_ERROR;
  }

  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;

  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
    MS_LOG(ERROR) << "expectedDataFile is not supported when inDataFile is not provided";
    std::cerr << "expectedDataFile is not supported when inDataFile is not provided" << std::endl;
    return RET_ERROR;
  }

  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
    MS_LOG(ERROR) << "exportDataFile is not supported when inDataFile is not provided";
    std::cerr << "exportDataFile is not supported when inDataFile is not provided" << std::endl;
    return RET_ERROR;
  }

  if (flags_->model_file_.empty()) {
    MS_LOG(ERROR) << "modelPath is required";
    std::cerr << "modelPath is required" << std::endl;
    return 1;
  }

  if (flags_->time_profiling_) {
    auto status = InitCallbackParameter();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Init callback parameter failed.";
      std::cerr << "Init callback parameter failed." << std::endl;
      return RET_ERROR;
    }
  }
  flags_->InitResizeDimsList();
  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
    return RET_ERROR;
  }
  return RET_OK;
}

namespace {
constexpr int kNumToPrint = 5;
}

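// Renders the profiling table: one row per op (name or type) with average time,
// share of total cost, call count, and total time, padded to column widths.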
int NetTrain::PrintResult(const std::vector<std::string> &title,
                          const std::map<std::string, std::pair<int, float>> &result) {
  std::vector<size_t> columnLenMax(kFieldsToPrint);
  std::vector<std::vector<std::string>> rows;

  for (auto &iter : result) {
    std::string stringBuf[kFieldsToPrint];
    std::vector<std::string> columns;
    size_t len;

    len = iter.first.size();
    if (len > columnLenMax.at(kField0)) {
      columnLenMax.at(kField0) = len + kPrintOffset;
    }
    columns.push_back(iter.first);

    stringBuf[kField1] = std::to_string(iter.second.second / flags_->epochs_);
    len = stringBuf[kField1].length();
    if (len > columnLenMax.at(kField1)) {
      columnLenMax.at(kField1) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField1]);

    stringBuf[kField2] = std::to_string(iter.second.second / op_cost_total_);
    len = stringBuf[kField2].length();
    if (len > columnLenMax.at(kField2)) {
      columnLenMax.at(kField2) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField2]);

    stringBuf[kField3] = std::to_string(iter.second.first);
    len = stringBuf[kField3].length();
    if (len > columnLenMax.at(kField3)) {
      columnLenMax.at(kField3) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField3]);

    stringBuf[kField4] = std::to_string(iter.second.second);
    len = stringBuf[kField4].length();
    if (len > columnLenMax.at(kField4)) {
      columnLenMax.at(kField4) = len + kPrintOffset;
    }
    columns.emplace_back(stringBuf[kField4]);

    rows.push_back(columns);
  }

  printf("-------------------------------------------------------------------------\n");
  for (int i = 0; i < kNumToPrint; i++) {
    auto printBuf = title[i];
    if (printBuf.size() > columnLenMax.at(i)) {
      columnLenMax.at(i) = printBuf.size();
    }
    printBuf.resize(columnLenMax.at(i), ' ');
    printf("%s\t", printBuf.c_str());
  }
  printf("\n");
  for (size_t i = 0; i < rows.size(); i++) {
    for (int j = 0; j < kNumToPrint; j++) {
      auto printBuf = rows[i][j];
      printBuf.resize(columnLenMax.at(j), ' ');
      printf("%s\t", printBuf.c_str());
    }
    printf("\n");
  }
  return RET_OK;
}

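// Command-line entry point: parses flags, initializes the NetTrain driver, and runs
// the benchmark.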
int RunNetTrain(int argc, const char **argv) {
  NetTrainFlags flags;
  Option<std::string> err = flags.ParseFlags(argc, argv);

  if (err.IsSome()) {
    std::cerr << err.Get() << std::endl;
    std::cerr << flags.Usage() << std::endl;
    return RET_ERROR;
  }

  if (flags.help) {
    std::cerr << flags.Usage() << std::endl;
    return RET_OK;
  }

  NetTrain net_trainer(&flags);
  auto status = net_trainer.Init();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "NetTrain init error: " << status;
    std::cerr << "NetTrain init error: " << status << std::endl;
    return RET_ERROR;
  }

  status = net_trainer.RunNetTrain();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run NetTrain "
                  << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
                  << " failed: " << status;
    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
              << " failed: " << status << std::endl;
    return RET_ERROR;
  }

  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
               << " Success.";
  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
            << " Success." << std::endl;
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore