• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/train_loop.h"
18 #include <sys/stat.h>
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/errorcode.h"
23 #include "include/dataset/iterator.h"
24 #include "src/common/log_adapter.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace lite {
29 using dataset::Dataset;
30 using dataset::Iterator;
31 using dataset::MSTensorVec;
32 
~TrainLoop()33 TrainLoop::~TrainLoop() {}
34 
Train(int epochs,Dataset * ds,std::vector<TrainLoopCallBack * > cbs,LoadDataFunc load_func)35 int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
36   MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
37   MS_CHECK_GE(epochs, 0, RET_ERROR);
38   auto ret = train_session_->Train();
39   if (ret != RET_OK) {
40     MS_LOG(ERROR) << "TrainLoop train failed";
41     return RET_ERROR;
42   }
43   TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
44 
45   if (load_func == nullptr) load_func = TrainLoop::LoadData;
46 
47   for (auto cb : cbs) {
48     MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
49     cb->Begin(cb_data);
50   }
51 
52   std::shared_ptr<Iterator> iter = ds->CreateIterator();
53   MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
54   for (int i = 0; i < epochs; i++) {
55     cb_data.epoch_ = epoch_++;
56     for (auto cb : cbs) cb->EpochBegin(cb_data);
57 
58     MSTensorVec row_vec;
59     int s = 0;
60 
61     auto status = iter->GetNextRow(&row_vec);
62     if (status != Status::OK()) {
63       MS_LOG(ERROR) << "Get row failed";
64       return RET_ERROR;
65     }
66     while (!row_vec.empty()) {
67       ret = load_func(cb_data.session_->GetInputs(), &row_vec);
68       if (ret != RET_OK) break;
69       cb_data.step_ = s++;
70       for (auto cb : cbs) cb->StepBegin(cb_data);
71 
72       ret = train_session_->RunGraph(before_cb_, after_cb_);
73       if (ret != RET_OK) {
74         MS_LOG(ERROR) << "Run Graph failed";
75         return RET_ERROR;
76       }
77       for (auto cb : cbs) cb->StepEnd(cb_data);
78       status = iter->GetNextRow(&row_vec);
79       if (status != Status::OK()) {
80         MS_LOG(ERROR) << "Get row failed";
81         return RET_ERROR;
82       }
83     }
84     bool break_loop = false;
85     for (auto cb : cbs) {
86       ret = cb->EpochEnd(cb_data);
87       if (ret != RET_CONTINUE) {
88         if (ret == RET_EXIT) {
89           MS_LOG(ERROR) << "Error in TrainLoop callback";
90           return RET_ERROR;
91         }
92         if (ret == RET_STOP_TRAINING) {
93           break_loop = true;
94         }
95       }
96     }
97     if (break_loop) {
98       break;
99     }
100   }
101   iter->Stop();
102   for (auto cb : cbs) cb->End(cb_data);
103   return RET_OK;
104 }
105 
Eval(Dataset * ds,std::vector<TrainLoopCallBack * > cbs,LoadDataFunc load_func,int max_steps)106 int TrainLoop::Eval(Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
107   MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
108   auto ret = train_session_->Eval();
109   if (ret != RET_OK) {
110     MS_LOG(ERROR) << "TrainLoop train failed";
111     return RET_ERROR;
112   }
113   TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
114 
115   if (load_func == nullptr) load_func = TrainLoop::LoadData;
116 
117   for (auto metric : metrics_) {
118     MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
119     metric->Clear();
120   }
121   for (auto cb : cbs) {
122     MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
123     cb->Begin(cb_data);
124   }
125   for (auto cb : cbs) cb->EpochBegin(cb_data);
126 
127   std::shared_ptr<Iterator> iter = ds->CreateIterator();
128   MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
129   MSTensorVec row_vec;
130   int s = 0;
131 
132   auto status = iter->GetNextRow(&row_vec);
133   if (status != Status::OK()) {
134     MS_LOG(ERROR) << "Get row failed";
135     return RET_ERROR;
136   }
137   while (!row_vec.empty()) {
138     if (s >= max_steps) break;
139     ret = load_func(cb_data.session_->GetInputs(), &row_vec);
140     if (ret != RET_OK) break;
141 
142     cb_data.step_ = ++s;
143     for (auto cb : cbs) cb->StepBegin(cb_data);
144 
145     train_session_->RunGraph(before_cb_, after_cb_);
146     for (auto cb : cbs) cb->StepEnd(cb_data);
147 
148     auto outputs = cb_data.session_->GetPredictions();
149     for (auto metric : metrics_) metric->Update(cb_data.session_->GetInputs(), outputs);
150     status = iter->GetNextRow(&row_vec);
151     if (status != Status::OK()) {
152       MS_LOG(ERROR) << "Get row failed";
153       return RET_ERROR;
154     }
155   }
156   iter->Stop();
157   for (auto cb : cbs) cb->EpochEnd(cb_data);
158   for (auto cb : cbs) cb->End(cb_data);
159 
160   return RET_OK;
161 }
162 
LoadData(std::vector<lite::Tensor * > inputs,dataset::MSTensorVec * row_vec)163 int TrainLoop::LoadData(std::vector<lite::Tensor *> inputs, dataset::MSTensorVec *row_vec) {
164   auto num_of_inputs = inputs.size();
165   if ((num_of_inputs == 0) || (row_vec == nullptr) || (num_of_inputs != row_vec->size())) {
166     return RET_STOP_TRAINING;
167   }
168 
169   for (size_t i = 0; i < num_of_inputs; i++) {
170     auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
171     const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
172     auto data_size = row_vec->at(i).DataSize();
173     if (data_size != inputs.at(i)->Size()) {
174       MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
175                       << ") does not match dataset size (" << data_size << ")\n";
176       return RET_STOP_TRAINING;
177     }
178     std::copy(row_data, row_data + data_size, input_data);
179   }
180   return RET_OK;
181 }
182 }  // namespace lite
183 
CreateTrainLoop(lite::LiteSession * train_session)184 lite::TrainLoop *session::TrainLoop::CreateTrainLoop(lite::LiteSession *train_session) {
185   auto loop = new (std::nothrow) lite::TrainLoop(train_session);
186   return loop;
187 }
188 }  // namespace mindspore
189