• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "include/train/train_loop_callback.h"
#include "include/train/metrics.h"
#include "src/litert/lite_session.h"
25 
26 namespace mindspore {
27 class MSTensor;
28 
29 namespace dataset {
30 class Dataset;
31 using MSTensorVec = std::vector<mindspore::MSTensor>;
32 }  // namespace dataset
33 
34 using LoadDataFunc = std::function<int(std::vector<lite::Tensor *> inputs, dataset::MSTensorVec *dataset_vec)>;
35 
36 namespace session {
37 
38 class TrainLoop {
39  public:
40   /// \brief Static method to create a TrainLoop object
41   ///
42   /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
43   ///
44   /// \return Pointer of MindSpore Lite TrainLoop
45   static lite::TrainLoop *CreateTrainLoop(lite::LiteSession *train_session);
46 
47   /// \brief Class destructor
48   virtual ~TrainLoop() = default;
49 
50   /// \brief Resets the epoch counter
51   ///
52   /// \return 0 on success or -1 in case of error
53   virtual int Reset() = 0;  // resets the epoch counter to 0.
54 
55   /// \brief Accessor to the LiteSession
56   ///
57   /// \return pointer of the train_session
58   const virtual lite::LiteSession *train_session() = 0;
59 
60   /// \brief Initialize object with metrics
61   ///
62   /// \param[in] verctor of metrics
63   ///
64   /// \return 0 on success or -1 in case of error
65   virtual int Init(std::vector<mindspore::session::Metrics *> metrics) = 0;
66 
67   /// \brief Accessor to TrainLoop metric objects
68   ///
69   /// \return vector of metrics
70   virtual std::vector<mindspore::session::Metrics *> GetMetrics() = 0;
71 
72   /// \brief Accessor to the Session KernelCallbacks
73   ///
74   /// \param[in] before Define a call_back_function to be called before running each node.
75   /// \param[in] after Define a call_back_function called after running each node.
76   ///
77   /// \return 0 on success or -1 in case of error
78   virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;
79 
80   /// \brief Performs the training Loop
81   ///
82   /// \param[in] epochs The number of epochs to run
83   /// \param[in] dataset Pointer to MindData Dataset object
84   /// \param[in] cbs A vector of TrainLoopCallBack objects
85   /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
86   ///
87   /// \return 0 on success or -1 in case of error
88   virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
89                     LoadDataFunc load_func) = 0;
90 
91   /// \brief Performs loop over all data in Eval Mode
92   ///
93   /// \param[in] dataset Pointer to MindData Dataset object
94   /// \param[in] cbs A vector of TrainLoopCallBack objects
95   /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model
96   /// \param[in] max_steps (with default = INT_MAX the method iterates all dataset)
97   ///
98   /// \return 0 on success or -1 in case of error
99   virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs,
100                    LoadDataFunc load_func, int max_steps) = 0;
101 };
102 }  // namespace session
103 }  // namespace mindspore
104 #endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
105