• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
17 #define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
18 
19 #include <cstddef>
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include "include/api/data_type.h"
24 #include "include/api/dual_abi_helper.h"
25 
26 namespace mindspore {
class Model;
class ModelImpl;
class CallbackImpl;

/// \brief Snapshot of the current execution state that is passed to every
///        TrainCallBack hook.
struct TrainCallBackData {
  /// \brief Build a callback snapshot.
  ///
  /// \param[in] train_mode whether the Model is in training mode
  /// \param[in] epoch      current training epoch (starts at 0); stored as unsigned int
  /// \param[in] step       current step within the epoch; stored as unsigned int
  /// \param[in] model      pointer to the Model object; not owned by this struct
  TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
      : train_mode_(train_mode),
        // The public members are unsigned while the constructor takes int;
        // make the (modular) integral conversion explicit instead of implicit.
        epoch_(static_cast<unsigned int>(epoch)),
        step_(static_cast<unsigned int>(step)),
        model_(model) {}

  bool train_mode_ = false; /**< training mode of the Model object */
  unsigned int epoch_ = 0;  /**< the current training epoch (starts at 0) */
  unsigned int step_ = 0;   /**< the current step within the epoch */
  Model *model_ = nullptr;  /**< pointer to the Model object (not owned) */
};
40 
/// \brief Result returned by TrainCallBack::EpochEnd to steer the train loop.
enum CallbackRetValue : uint32_t {
  kContinue = 0U,                ///< continue training
  kStopTraining,                 ///< stop training (e.g., due to achieved accuracy)
  kExit,                         ///< exit training (due to an error of some sort)
  kUnknownRetValue = 0xFFFFFFFFU ///< sentinel value with all 32 bits set
};
47 
48 class TrainCallBack {
49  public:
50   virtual ~TrainCallBack() = default;
51 
52   /// \brief This method is called once before the network executing
53   ///
54   /// \param[in] cb_data info about current execution
Begin(const TrainCallBackData & cb_data)55   virtual void Begin(const TrainCallBackData &cb_data) {}
56 
57   /// \brief This method is called once following the network execution
58   ///
59   /// \param[in] cb_data info about current execution
End(const TrainCallBackData & cb_data)60   virtual void End(const TrainCallBackData &cb_data) {}
61 
62   /// \brief This method is called at the beginning of each epoch
63   ///
64   /// \param[in] cb_data info about current execution
EpochBegin(const TrainCallBackData & cb_data)65   virtual void EpochBegin(const TrainCallBackData &cb_data) {}
66 
67   /// \brief This method is called after the run of each epoch
68   ///
69   /// \param[in] cb_data info about current execution
70   ///
71   /// \return indication if to continue in the train loop:
72   ///         RET_CONTINUE -- continue training
73   ///         RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
74   ///         RET_EXIT -- Exit training (due to error of some sort)
EpochEnd(const TrainCallBackData & cb_data)75   virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }
76 
77   /// \brief This method is called at the beginning of each step
78   ///
79   /// \param[in] cb_data info about current execution
StepBegin(const TrainCallBackData & cb_data)80   virtual void StepBegin(const TrainCallBackData &cb_data) {}
81 
82   /// \brief This method is called after each step is ran
83   ///
84   /// \param[in] cb_data info about current execution
StepEnd(const TrainCallBackData & cb_data)85   virtual void StepEnd(const TrainCallBackData &cb_data) {}
86 
87  protected:
88   friend class Model;
89   friend class ModelImpl;
90   CallbackImpl* callback_impl_ = nullptr;
91 };
92 
93 }  // namespace mindspore
94 #endif  // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
95