• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_
17 #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_
18 #include <vector>
19 #include <string>
20 #include <tuple>
21 #include <memory>
22 #include <unordered_map>
23 #include "include/errorcode.h"
24 #include "include/train/train_loop.h"
25 #include "include/train/metrics.h"
26 #include "include/dataset/datasets.h"
27 #include "include/dataset/iterator.h"
28 #include "src/common/log_adapter.h"
29 
30 namespace mindspore {
31 namespace lite {
32 
33 class TrainLoop : virtual public session::TrainLoop {
34  public:
TrainLoop(session::LiteSession * session)35   explicit TrainLoop(session::LiteSession *session) : train_session_(session) {}
36 
train_session()37   const session::LiteSession *train_session() override { return train_session_; }
38 
Reset()39   int Reset() override {
40     epoch_ = 0;
41     return RET_OK;
42   }
43 
44   virtual ~TrainLoop();
45 
Init(std::vector<mindspore::session::Metrics * > metrics)46   int Init(std::vector<mindspore::session::Metrics *> metrics) override {
47     metrics_ = metrics;
48     return RET_OK;
49   }
50 
SetKernelCallBack(const KernelCallBack & before,const KernelCallBack & after)51   int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) override {
52     before_cb_ = before;
53     after_cb_ = after;
54     return RET_OK;
55   }
56 
57   int Train(int epochs, dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs,
58             LoadDataFunc load_func = nullptr) override;
59   int Eval(dataset::Dataset *dataset, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr,
60            int max_steps = 0) override;
61 
GetMetrics()62   std::vector<mindspore::session::Metrics *> GetMetrics() override { return metrics_; }
63 
64  protected:
65   static int LoadData(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec);
66 
67   session::LiteSession *train_session_ = nullptr;
68   unsigned int epoch_ = 0;
69   KernelCallBack before_cb_ = nullptr;
70   KernelCallBack after_cb_ = nullptr;
71   int batch_size;
72   std::vector<mindspore::session::Metrics *> metrics_;
73 };
74 }  // namespace lite
75 }  // namespace mindspore
76 #endif  // MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_
77