1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/train_loop.h"
18 #include <sys/stat.h>
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/errorcode.h"
23 #include "include/dataset/iterator.h"
24 #include "src/common/log_adapter.h"
25 #include "nnacl/op_base.h"
26
27 namespace mindspore {
28 namespace lite {
29 using dataset::Dataset;
30 using dataset::Iterator;
31 using dataset::MSTensorVec;
32 using session::RET_CONTINUE;
33 using session::RET_EXIT;
34 using session::RET_STOP_TRAINING;
35
~TrainLoop()36 TrainLoop::~TrainLoop() {}
37
Train(int epochs,Dataset * ds,std::vector<session::TrainLoopCallBack * > cbs,LoadDataFunc load_func)38 int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
39 MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
40 MS_CHECK_GE(epochs, 0, RET_ERROR);
41 auto ret = train_session_->Train();
42 if (ret != RET_OK) {
43 MS_LOG(ERROR) << "TrainLoop train failed";
44 return RET_ERROR;
45 }
46 session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
47
48 if (load_func == nullptr) load_func = TrainLoop::LoadData;
49
50 for (auto cb : cbs) {
51 MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
52 cb->Begin(cb_data);
53 }
54
55 std::shared_ptr<Iterator> iter = ds->CreateIterator();
56 MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
57 for (int i = 0; i < epochs; i++) {
58 cb_data.epoch_ = epoch_++;
59 for (auto cb : cbs) cb->EpochBegin(cb_data);
60
61 MSTensorVec row_vec;
62 int s = 0;
63
64 auto status = iter->GetNextRow(&row_vec);
65 if (status != Status::OK()) {
66 MS_LOG(ERROR) << "Get row failed";
67 return RET_ERROR;
68 }
69 while (!row_vec.empty()) {
70 ret = load_func(cb_data.session_->GetInputs(), &row_vec);
71 if (ret != RET_OK) break;
72 cb_data.step_ = s++;
73 for (auto cb : cbs) cb->StepBegin(cb_data);
74
75 ret = train_session_->RunGraph(before_cb_, after_cb_);
76 if (ret != RET_OK) {
77 MS_LOG(ERROR) << "Run Graph failed";
78 return RET_ERROR;
79 }
80 for (auto cb : cbs) cb->StepEnd(cb_data);
81 status = iter->GetNextRow(&row_vec);
82 if (status != Status::OK()) {
83 MS_LOG(ERROR) << "Get row failed";
84 return RET_ERROR;
85 }
86 }
87 int break_loop = false;
88 for (auto cb : cbs) {
89 ret = cb->EpochEnd(cb_data);
90 if (ret != RET_CONTINUE) {
91 if (ret == RET_EXIT) {
92 MS_LOG(ERROR) << "Error in TrainLoop callback";
93 return RET_ERROR;
94 }
95 if (ret == RET_STOP_TRAINING) {
96 break_loop = true;
97 }
98 }
99 }
100 if (break_loop) {
101 break;
102 }
103 }
104 iter->Stop();
105 for (auto cb : cbs) cb->End(cb_data);
106 return RET_OK;
107 }
108
Eval(Dataset * ds,std::vector<session::TrainLoopCallBack * > cbs,LoadDataFunc load_func,int max_steps)109 int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
110 MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
111 auto ret = train_session_->Eval();
112 if (ret != RET_OK) {
113 MS_LOG(ERROR) << "TrainLoop train failed";
114 return RET_ERROR;
115 }
116 session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
117
118 if (load_func == nullptr) load_func = TrainLoop::LoadData;
119
120 for (auto metric : metrics_) {
121 MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
122 metric->Clear();
123 }
124 for (auto cb : cbs) {
125 MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
126 cb->Begin(cb_data);
127 }
128 for (auto cb : cbs) cb->EpochBegin(cb_data);
129
130 std::shared_ptr<Iterator> iter = ds->CreateIterator();
131 MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
132 MSTensorVec row_vec;
133 int s = 0;
134
135 auto status = iter->GetNextRow(&row_vec);
136 if (status != Status::OK()) {
137 MS_LOG(ERROR) << "Get row failed";
138 return RET_ERROR;
139 }
140 while (!row_vec.empty()) {
141 if (s >= max_steps) break;
142 ret = load_func(cb_data.session_->GetInputs(), &row_vec);
143 if (ret != RET_OK) break;
144
145 cb_data.step_ = ++s;
146 for (auto cb : cbs) cb->StepBegin(cb_data);
147
148 train_session_->RunGraph(before_cb_, after_cb_);
149 for (auto cb : cbs) cb->StepEnd(cb_data);
150
151 auto outputs = cb_data.session_->GetPredictions();
152 for (auto metric : metrics_) metric->Update(cb_data.session_->GetInputs(), outputs);
153 status = iter->GetNextRow(&row_vec);
154 if (status != Status::OK()) {
155 MS_LOG(ERROR) << "Get row failed";
156 return RET_ERROR;
157 }
158 }
159 iter->Stop();
160 for (auto cb : cbs) cb->EpochEnd(cb_data);
161 for (auto cb : cbs) cb->End(cb_data);
162
163 return RET_OK;
164 }
165
LoadData(std::vector<tensor::MSTensor * > inputs,dataset::MSTensorVec * row_vec)166 int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *row_vec) {
167 auto num_of_inputs = inputs.size();
168 if ((num_of_inputs == 0) || (row_vec == nullptr) || (num_of_inputs != row_vec->size())) {
169 return RET_STOP_TRAINING;
170 }
171
172 for (unsigned int i = 0; i < num_of_inputs; i++) {
173 auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
174 const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
175 auto data_size = row_vec->at(i).DataSize();
176 if (data_size != inputs.at(i)->Size()) {
177 MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
178 << ") does not match dataset size (" << data_size << ")\n";
179 return RET_STOP_TRAINING;
180 }
181 std::copy(row_data, row_data + data_size, input_data);
182 }
183 return RET_OK;
184 }
185 } // namespace lite
186
CreateTrainLoop(session::LiteSession * train_session)187 session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::LiteSession *train_session) {
188 auto loop = new (std::nothrow) lite::TrainLoop(train_session);
189 return loop;
190 }
191 } // namespace mindspore
192