• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#include <climits>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "include/train/train_loop_callback.h"
#include "include/train/metrics.h"
#include "include/lite_session.h"
26 
27 namespace mindspore {
28 class MSTensor;
29 
30 namespace dataset {
31 class Dataset;
32 using MSTensorVec = std::vector<mindspore::MSTensor>;
33 }  // namespace dataset
34 
35 using LoadDataFunc = std::function<int(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec)>;
36 
37 namespace session {
38 
39 class TrainLoop {
40  public:
41   /// \brief Static method to create a TrainLoop object
42   ///
43   /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
44   ///
45   /// \return Pointer of MindSpore Lite TrainLoop
46   static TrainLoop *CreateTrainLoop(session::LiteSession *train_session);
47 
48   /// \brief Class destructor
49   virtual ~TrainLoop() = default;
50 
51   /// \brief Resets the epoch counter
52   ///
53   /// \return 0 on success or -1 in case of error
54   virtual int Reset() = 0;  // resets the epoch counter to 0.
55 
56   /// \brief Accessor to the LiteSession
57   ///
58   /// \return pointer of the train_session
59   const virtual session::LiteSession *train_session() = 0;
60 
61   /// \brief Initialize object with metrics
62   ///
63   /// \param[in] verctor of metrics
64   ///
65   /// \return 0 on success or -1 in case of error
66   virtual int Init(std::vector<mindspore::session::Metrics *> metrics) = 0;
67 
68   /// \brief Accessor to TrainLoop metric objects
69   ///
70   /// \return vector of metrics
71   virtual std::vector<mindspore::session::Metrics *> GetMetrics() = 0;
72 
73   /// \brief Accessor to the Session KernelCallbacks
74   ///
75   /// \param[in] before Define a call_back_function to be called before running each node.
76   /// \param[in] after Define a call_back_function called after running each node.
77   ///
78   /// \return 0 on success or -1 in case of error
79   virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;
80 
81   /// \brief Performs the training Loop
82   ///
83   /// \param[in] epoch The number of epochs to run
84   /// \param[in] dataset Pointer to MindData Dataset object
85   /// \param[in] cbs A vector of TrainLoopCallBack objects
86   /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
87   ///
88   /// \return 0 on success or -1 in case of error
89   virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs,
90                     LoadDataFunc load_func = nullptr) = 0;
91 
92   /// \brief Performs loop over all data in Eval Mode
93   ///
94   /// \param[in] dataset Pointer to MindData Dataset object
95   /// \param[in] cbs A vector of TrainLoopCallBack objects
96   /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
97   /// \param[in] max_steps (with default = INT_MAX the method iterates all dataset)
98   ///
99   /// \return 0 on success or -1 in case of error
100   virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs,
101                    LoadDataFunc load_func = nullptr, int max_steps = INT_MAX) = 0;
102 };
103 }  // namespace session
104 }  // namespace mindspore
105 #endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
106