• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
17 #define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
18 
19 #include <cstddef>
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include <utility>
24 #include "include/api/data_type.h"
25 #include "include/api/dual_abi_helper.h"
26 #include "include/api/types.h"
27 
28 namespace mindspore {
// Forward declarations: in this header these types are only used through
// pointers and friend declarations, so their full definitions are not needed.
class Model;
class ModelImpl;
class CallbackImpl;

/// \brief Pair of an integer and a float.
/// NOTE(review): presumably one (x, y) sample of a plotted metric -- confirm against the users of this alias.
using GraphPoint = std::pair<int, float>;
34 
35 struct MS_API TrainCallBackData {
TrainCallBackDataTrainCallBackData36   TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
37       : train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {}
38 
39   bool train_mode_;       /**< training mode of LiteSession object */
40   unsigned int epoch_;    /**< the current training epoch (starts at 0) */
41   unsigned int step_ = 0; /**< the current step within the epoch */
42   Model *model_;          /**< pointer to the Model object */
43 };
44 
/// \brief Value returned by a training callback to steer the train loop.
enum CallbackRetValue : uint32_t {
  kContinue = 0,                 ///< continue training
  kStopTraining = 1,             ///< stop training (e.g., due to achieved accuracy)
  kExit = 2,                     ///< exit training (due to error of some sort)
  kUnknownRetValue = 0xFFFFFFFF  ///< all bits set; NOTE(review): presumably an "invalid/unset" sentinel -- confirm
};
46 
47 class MS_API TrainCallBack {
48  public:
49   virtual ~TrainCallBack() = default;
50 
51   /// \brief This method is called once before the network executing
52   ///
53   /// \param[in] cb_data info about current execution
Begin(const TrainCallBackData & cb_data)54   virtual void Begin(const TrainCallBackData &cb_data) {}
55 
56   /// \brief This method is called once following the network execution
57   ///
58   /// \param[in] cb_data info about current execution
End(const TrainCallBackData & cb_data)59   virtual void End(const TrainCallBackData &cb_data) {}
60 
61   /// \brief This method is called at the beginning of each epoch
62   ///
63   /// \param[in] cb_data info about current execution
EpochBegin(const TrainCallBackData & cb_data)64   virtual void EpochBegin(const TrainCallBackData &cb_data) {}
65 
66   /// \brief This method is called after the run of each epoch
67   ///
68   /// \param[in] cb_data info about current execution
69   ///
70   /// \return indication if to continue in the train loop:
71   ///         RET_CONTINUE -- continue training
72   ///         RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
73   ///         RET_EXIT -- Exit training (due to error of some sort)
EpochEnd(const TrainCallBackData & cb_data)74   virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }
75 
76   /// \brief This method is called at the beginning of each step
77   ///
78   /// \param[in] cb_data info about current execution
StepBegin(const TrainCallBackData & cb_data)79   virtual void StepBegin(const TrainCallBackData &cb_data) {}
80 
81   /// \brief This method is called after each step is ran
82   ///
83   /// \param[in] cb_data info about current execution
StepEnd(const TrainCallBackData & cb_data)84   virtual void StepEnd(const TrainCallBackData &cb_data) {}
85 
86  protected:
87   friend class Model;
88   friend class ModelImpl;
89   CallbackImpl *callback_impl_ = nullptr;
90 };
91 
92 }  // namespace mindspore
93 #endif  // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
94