• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/train_loop.h"
18 #include <sys/stat.h>
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "include/errorcode.h"
23 #include "include/dataset/iterator.h"
24 #include "src/common/log_adapter.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace lite {
29 using dataset::Dataset;
30 using dataset::Iterator;
31 using dataset::MSTensorVec;
32 using session::RET_CONTINUE;
33 using session::RET_EXIT;
34 using session::RET_STOP_TRAINING;
35 
~TrainLoop()36 TrainLoop::~TrainLoop() {}
37 
Train(int epochs,Dataset * ds,std::vector<session::TrainLoopCallBack * > cbs,LoadDataFunc load_func)38 int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
39   MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
40   MS_CHECK_GE(epochs, 0, RET_ERROR);
41   auto ret = train_session_->Train();
42   if (ret != RET_OK) {
43     MS_LOG(ERROR) << "TrainLoop train failed";
44     return RET_ERROR;
45   }
46   session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
47 
48   if (load_func == nullptr) load_func = TrainLoop::LoadData;
49 
50   for (auto cb : cbs) {
51     MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
52     cb->Begin(cb_data);
53   }
54 
55   std::shared_ptr<Iterator> iter = ds->CreateIterator();
56   MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
57   for (int i = 0; i < epochs; i++) {
58     cb_data.epoch_ = epoch_++;
59     for (auto cb : cbs) cb->EpochBegin(cb_data);
60 
61     MSTensorVec row_vec;
62     int s = 0;
63 
64     auto status = iter->GetNextRow(&row_vec);
65     if (status != Status::OK()) {
66       MS_LOG(ERROR) << "Get row failed";
67       return RET_ERROR;
68     }
69     while (!row_vec.empty()) {
70       ret = load_func(cb_data.session_->GetInputs(), &row_vec);
71       if (ret != RET_OK) break;
72       cb_data.step_ = s++;
73       for (auto cb : cbs) cb->StepBegin(cb_data);
74 
75       ret = train_session_->RunGraph(before_cb_, after_cb_);
76       if (ret != RET_OK) {
77         MS_LOG(ERROR) << "Run Graph failed";
78         return RET_ERROR;
79       }
80       for (auto cb : cbs) cb->StepEnd(cb_data);
81       status = iter->GetNextRow(&row_vec);
82       if (status != Status::OK()) {
83         MS_LOG(ERROR) << "Get row failed";
84         return RET_ERROR;
85       }
86     }
87     int break_loop = false;
88     for (auto cb : cbs) {
89       ret = cb->EpochEnd(cb_data);
90       if (ret != RET_CONTINUE) {
91         if (ret == RET_EXIT) {
92           MS_LOG(ERROR) << "Error in TrainLoop callback";
93           return RET_ERROR;
94         }
95         if (ret == RET_STOP_TRAINING) {
96           break_loop = true;
97         }
98       }
99     }
100     if (break_loop) {
101       break;
102     }
103   }
104   iter->Stop();
105   for (auto cb : cbs) cb->End(cb_data);
106   return RET_OK;
107 }
108 
Eval(Dataset * ds,std::vector<session::TrainLoopCallBack * > cbs,LoadDataFunc load_func,int max_steps)109 int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
110   MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
111   auto ret = train_session_->Eval();
112   if (ret != RET_OK) {
113     MS_LOG(ERROR) << "TrainLoop train failed";
114     return RET_ERROR;
115   }
116   session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
117 
118   if (load_func == nullptr) load_func = TrainLoop::LoadData;
119 
120   for (auto metric : metrics_) {
121     MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
122     metric->Clear();
123   }
124   for (auto cb : cbs) {
125     MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
126     cb->Begin(cb_data);
127   }
128   for (auto cb : cbs) cb->EpochBegin(cb_data);
129 
130   std::shared_ptr<Iterator> iter = ds->CreateIterator();
131   MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
132   MSTensorVec row_vec;
133   int s = 0;
134 
135   auto status = iter->GetNextRow(&row_vec);
136   if (status != Status::OK()) {
137     MS_LOG(ERROR) << "Get row failed";
138     return RET_ERROR;
139   }
140   while (!row_vec.empty()) {
141     if (s >= max_steps) break;
142     ret = load_func(cb_data.session_->GetInputs(), &row_vec);
143     if (ret != RET_OK) break;
144 
145     cb_data.step_ = ++s;
146     for (auto cb : cbs) cb->StepBegin(cb_data);
147 
148     train_session_->RunGraph(before_cb_, after_cb_);
149     for (auto cb : cbs) cb->StepEnd(cb_data);
150 
151     auto outputs = cb_data.session_->GetPredictions();
152     for (auto metric : metrics_) metric->Update(cb_data.session_->GetInputs(), outputs);
153     status = iter->GetNextRow(&row_vec);
154     if (status != Status::OK()) {
155       MS_LOG(ERROR) << "Get row failed";
156       return RET_ERROR;
157     }
158   }
159   iter->Stop();
160   for (auto cb : cbs) cb->EpochEnd(cb_data);
161   for (auto cb : cbs) cb->End(cb_data);
162 
163   return RET_OK;
164 }
165 
LoadData(std::vector<tensor::MSTensor * > inputs,dataset::MSTensorVec * row_vec)166 int TrainLoop::LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *row_vec) {
167   auto num_of_inputs = inputs.size();
168   if ((num_of_inputs == 0) || (row_vec == nullptr) || (num_of_inputs != row_vec->size())) {
169     return RET_STOP_TRAINING;
170   }
171 
172   for (unsigned int i = 0; i < num_of_inputs; i++) {
173     auto *input_data = reinterpret_cast<unsigned char *>(inputs.at(i)->MutableData());
174     const auto *row_data = reinterpret_cast<const unsigned char *>(row_vec->at(i).MutableData());
175     auto data_size = row_vec->at(i).DataSize();
176     if (data_size != inputs.at(i)->Size()) {
177       MS_LOG(WARNING) << "Model Input tensor " << i << " size (" << inputs.at(i)->Size()
178                       << ") does not match dataset size (" << data_size << ")\n";
179       return RET_STOP_TRAINING;
180     }
181     std::copy(row_data, row_data + data_size, input_data);
182   }
183   return RET_OK;
184 }
185 }  // namespace lite
186 
CreateTrainLoop(session::LiteSession * train_session)187 session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::LiteSession *train_session) {
188   auto loop = new (std::nothrow) lite::TrainLoop(train_session);
189   return loop;
190 }
191 }  // namespace mindspore
192