• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
18 #include <vector>
19 #include <string>
20 #include <tuple>
21 #include <unordered_map>
22 
23 namespace mindspore {
24 namespace session {
25 
class LiteSession;
class TrainLoop;

/// \brief Snapshot of the training-loop state handed to every TrainLoopCallBack hook.
struct TrainLoopCallBackData {
  /// \brief Builds the callback data; the step counter always starts at zero.
  ///
  /// \param[in] train_mode training mode of the LiteSession object
  /// \param[in] epoch      current training epoch (starts at 0)
  /// \param[in] session    pointer to the LiteSession
  /// \param[in] loop       pointer to the TrainLoop
  TrainLoopCallBackData(bool train_mode, unsigned int epoch, LiteSession *session, TrainLoop *loop)
      : train_mode_{train_mode}, epoch_{epoch}, session_{session}, loop_{loop} {}

  /// training mode of LiteSession object
  bool train_mode_;
  /// the current training epoch (starts at 0)
  unsigned int epoch_;
  /// the current step within the epoch
  unsigned int step_{0};
  /// pointer to the LiteSession
  LiteSession *session_;
  /// pointer to the TrainLoop
  TrainLoop *loop_;
};
39 
// Status codes returned by TrainLoopCallBack::EpochEnd to steer the train loop.
constexpr int RET_CONTINUE{0};       ///< continue training
constexpr int RET_STOP_TRAINING{1};  ///< stop training (e.g., due to achieved accuracy)
constexpr int RET_EXIT{2};           ///< exit training (due to error of some sort)
43 
44 class TrainLoopCallBack {
45  public:
46   virtual ~TrainLoopCallBack() = default;
47 
48   /// \brief This method is called once before the network executing
49   ///
50   /// \param[in] cb_data info about current execution
Begin(const TrainLoopCallBackData & cb_data)51   virtual void Begin(const TrainLoopCallBackData &cb_data) {}
52 
53   /// \brief This method is called once following the network execution
54   ///
55   /// \param[in] cb_data info about current execution
End(const TrainLoopCallBackData & cb_data)56   virtual void End(const TrainLoopCallBackData &cb_data) {}
57 
58   /// \brief This method is called at the beginning of each epoch
59   ///
60   /// \param[in] cb_data info about current execution
EpochBegin(const TrainLoopCallBackData & cb_data)61   virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
62 
63   /// \brief This method is called after the run of each epoch
64   ///
65   /// \param[in] cb_data info about current execution
66   ///
67   /// \return indication if to continue in the train loop:
68   ///         RET_CONTINUE -- continue training
69   ///         RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
70   ///         RET_EXIT -- Exit training (due to error of some sort)
EpochEnd(const TrainLoopCallBackData & cb_data)71   virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
72 
73   /// \brief This method is called at the beginning of each step
74   ///
75   /// \param[in] cb_data info about current execution
StepBegin(const TrainLoopCallBackData & cb_data)76   virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
77 
78   /// \brief This method is called after each step is ran
79   ///
80   /// \param[in] cb_data info about current execution
StepEnd(const TrainLoopCallBackData & cb_data)81   virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
82 };
83 
84 }  // namespace session
85 }  // namespace mindspore
86 #endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
87