• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
18 #include <vector>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <unordered_map>
23 #include "include/api/callback/callback.h"
24 
25 namespace mindspore {
26 namespace lite {
27 class LiteSession;
28 class TrainLoop;
29 
30 struct TrainLoopCallBackData {
TrainLoopCallBackDataTrainLoopCallBackData31   TrainLoopCallBackData(bool train_mode, unsigned int epoch, LiteSession *session, TrainLoop *loop)
32       : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}
33 
34   bool train_mode_;       /**< training mode of LiteSession object */
35   unsigned int epoch_;    /**< the current training epoch (starts at 0) */
36   unsigned int step_ = 0; /**< the current step within the epoch */
37   LiteSession *session_;  /**< pointer to the LiteSession */
38   TrainLoop *loop_;
39 };
40 
// Codes returned by TrainLoopCallBack::EpochEnd() to tell the train loop how to proceed.
constexpr int RET_CONTINUE{0};       // continue training
constexpr int RET_STOP_TRAINING{1};  // stop training (e.g., due to achieved accuracy)
constexpr int RET_EXIT{2};           // exit training (due to an error of some sort)
44 
45 class TrainLoopCallBack {
46  public:
47   virtual ~TrainLoopCallBack() = default;
48 
49   /// \brief This method is called once before the network executing
50   ///
51   /// \param[in] cb_data info about current execution
Begin(const TrainLoopCallBackData & cb_data)52   virtual void Begin(const TrainLoopCallBackData &cb_data) {}
53 
54   /// \brief This method is called once following the network execution
55   ///
56   /// \param[in] cb_data info about current execution
End(const TrainLoopCallBackData & cb_data)57   virtual void End(const TrainLoopCallBackData &cb_data) {}
58 
59   /// \brief This method is called at the beginning of each epoch
60   ///
61   /// \param[in] cb_data info about current execution
EpochBegin(const TrainLoopCallBackData & cb_data)62   virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
63 
64   /// \brief This method is called after the run of each epoch
65   ///
66   /// \param[in] cb_data info about current execution
67   ///
68   /// \return indication if to continue in the train loop:
69   ///         RET_CONTINUE -- continue training
70   ///         RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
71   ///         RET_EXIT -- Exit training (due to error of some sort)
EpochEnd(const TrainLoopCallBackData & cb_data)72   virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
73 
74   /// \brief This method is called at the beginning of each step
75   ///
76   /// \param[in] cb_data info about current execution
StepBegin(const TrainLoopCallBackData & cb_data)77   virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
78 
79   /// \brief This method is called after each step is ran
80   ///
81   /// \param[in] cb_data info about current execution
StepEnd(const TrainLoopCallBackData & cb_data)82   virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
83 };
84 
85 }  // namespace lite
86 }  // namespace mindspore
87 #endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
88