1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/train_loop.h"
18 #include <sys/stat.h>
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/errorcode.h"
23 #include "include/dataset/iterator.h"
24 #include "src/common/log_adapter.h"
25 #include "nnacl/op_base.h"
26
27 namespace mindspore {
28 namespace lite {
29 using dataset::Dataset;
30 using dataset::Iterator;
31 using dataset::MSTensorVec;
32
~TrainLoop()33 TrainLoop::~TrainLoop() {}
34
Train(int epochs,Dataset * ds,std::vector<TrainLoopCallBack * > cbs,LoadDataFunc load_func)35 int TrainLoop::Train(int epochs, Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
36 MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
37 MS_CHECK_GE(epochs, 0, RET_ERROR);
38 auto ret = train_session_->Train();
39 if (ret != RET_OK) {
40 MS_LOG(ERROR) << "TrainLoop train failed";
41 return RET_ERROR;
42 }
43 TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
44
45 if (load_func == nullptr) load_func = TrainLoop::LoadData;
46
47 for (auto cb : cbs) {
48 MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
49 cb->Begin(cb_data);
50 }
51
52 std::shared_ptr<Iterator> iter = ds->CreateIterator();
53 MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
54 for (int i = 0; i < epochs; i++) {
55 cb_data.epoch_ = epoch_++;
56 for (auto cb : cbs) cb->EpochBegin(cb_data);
57
58 MSTensorVec row_vec;
59 int s = 0;
60
61 auto status = iter->GetNextRow(&row_vec);
62 if (status != Status::OK()) {
63 MS_LOG(ERROR) << "Get row failed";
64 return RET_ERROR;
65 }
66 while (!row_vec.empty()) {
67 ret = load_func(cb_data.session_->GetInputs(), &row_vec);
68 if (ret != RET_OK) break;
69 cb_data.step_ = s++;
70 for (auto cb : cbs) cb->StepBegin(cb_data);
71
72 ret = train_session_->RunGraph(before_cb_, after_cb_);
73 if (ret != RET_OK) {
74 MS_LOG(ERROR) << "Run Graph failed";
75 return RET_ERROR;
76 }
77 for (auto cb : cbs) cb->StepEnd(cb_data);
78 status = iter->GetNextRow(&row_vec);
79 if (status != Status::OK()) {
80 MS_LOG(ERROR) << "Get row failed";
81 return RET_ERROR;
82 }
83 }
84 bool break_loop = false;
85 for (auto cb : cbs) {
86 ret = cb->EpochEnd(cb_data);
87 if (ret != RET_CONTINUE) {
88 if (ret == RET_EXIT) {
89 MS_LOG(ERROR) << "Error in TrainLoop callback";
90 return RET_ERROR;
91 }
92 if (ret == RET_STOP_TRAINING) {
93 break_loop = true;
94 }
95 }
96 }
97 if (break_loop) {
98 break;
99 }
100 }
101 iter->Stop();
102 for (auto cb : cbs) cb->End(cb_data);
103 return RET_OK;
104 }
105
Eval(Dataset * ds,std::vector<TrainLoopCallBack * > cbs,LoadDataFunc load_func,int max_steps)106 int TrainLoop::Eval(Dataset *ds, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
107 MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
108 auto ret = train_session_->Eval();
109 if (ret != RET_OK) {
110 MS_LOG(ERROR) << "TrainLoop train failed";
111 return RET_ERROR;
112 }
113 TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
114
115 if (load_func == nullptr) load_func = TrainLoop::LoadData;
116
117 for (auto metric : metrics_) {
118 MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
119 metric->Clear();
120 }
121 for (auto cb : cbs) {
122 MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
123 cb->Begin(cb_data);
124 }
125 for (auto cb : cbs) cb->EpochBegin(cb_data);
126
127 std::shared_ptr<Iterator> iter = ds->CreateIterator();
128 MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
129 MSTensorVec row_vec;
130 int s = 0;
131
132 auto status = iter->GetNextRow(&row_vec);
133 if (status != Status::OK()) {
134 MS_LOG(ERROR) << "Get row failed";
135 return RET_ERROR;
136 }
137 while (!row_vec.empty()) {
138 if (s >= max_steps) break;
139 ret = load_func(cb_data.session_->GetInputs(), &row_vec);
140 if (ret != RET_OK) break;
141
142 cb_data.step_ = ++s;
143 for (auto cb : cbs) cb->StepBegin(cb_data);
144
145 train_session_->RunGraph(before_cb_, after_cb_);
146 for (auto cb : cbs) cb->StepEnd(cb_data);
147
148 auto outputs = cb_data.session_->GetPredictions();
149 for (auto metric : metrics_) metric->Update(cb_data.session_->GetInputs(), outputs);
150 status = iter->GetNextRow(&row_vec);
151 if (status != Status::OK()) {
152 MS_LOG(ERROR) << "Get row failed";
153 return RET_ERROR;
154 }
155 }
156 iter->Stop();
157 for (auto cb : cbs) cb->EpochEnd(cb_data);
158 for (auto cb : cbs) cb->End(cb_data);
159
160 return RET_OK;
161 }
162
LoadData(std::vector<lite::Tensor * > inputs,dataset::MSTensorVec * row_vec)163 int TrainLoop::LoadData(std::vector<lite::Tensor *> inputs, dataset::MSTensorVec *row_vec) {
164 auto num_of_inputs = inputs.size();
165 if ((num_of_inputs == 0) || (row_vec == nullptr) || (num_of_inputs != row_vec->size())) {
166 return RET_STOP_TRAINING;
167 }
168
169 for (size_t i = 0; i < num_of_inputs; i++) {
170 auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
171 const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
172 auto data_size = row_vec->at(i).DataSize();
173 if (data_size != inputs.at(i)->Size()) {
174 MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
175 << ") does not match dataset size (" << data_size << ")\n";
176 return RET_STOP_TRAINING;
177 }
178 std::copy(row_data, row_data + data_size, input_data);
179 }
180 return RET_OK;
181 }
182 } // namespace lite
183
CreateTrainLoop(lite::LiteSession * train_session)184 lite::TrainLoop *session::TrainLoop::CreateTrainLoop(lite::LiteSession *train_session) {
185 auto loop = new (std::nothrow) lite::TrainLoop(train_session);
186 return loop;
187 }
188 } // namespace mindspore
189