1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_ 18 #define MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_ 19 20 #include <functional> 21 #include <map> 22 #include <string> 23 #include <vector> 24 #include <memory> 25 #include <utility> 26 #include <unordered_map> 27 #include "include/api/model.h" 28 #include "include/api/context.h" 29 #include "include/api/cell.h" 30 #include "include/lite_session.h" 31 #include "include/train/train_loop_callback.h" 32 33 namespace mindspore { 34 35 class TrainLoopCallBackAdapter : public session::TrainLoopCallBack { 36 public: TrainLoopCallBackAdapter(Model * model,TrainCallBack * call_back)37 explicit TrainLoopCallBackAdapter(Model *model, TrainCallBack *call_back) : model_(model), call_back_(call_back) {} 38 TrainLoopCallBackAdapter() = delete; 39 Begin(const session::TrainLoopCallBackData & i_cb_data)40 void Begin(const session::TrainLoopCallBackData &i_cb_data) override { 41 call_back_->Begin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 42 }; 43 End(const session::TrainLoopCallBackData & i_cb_data)44 void End(const session::TrainLoopCallBackData &i_cb_data) override { 45 call_back_->End(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 46 }; 47 EpochBegin(const session::TrainLoopCallBackData & i_cb_data)48 void EpochBegin(const session::TrainLoopCallBackData &i_cb_data) override { 49 call_back_->EpochBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 50 }; 51 EpochEnd(const session::TrainLoopCallBackData & i_cb_data)52 int EpochEnd(const session::TrainLoopCallBackData &i_cb_data) override { 53 return call_back_->EpochEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 54 }; 55 StepBegin(const session::TrainLoopCallBackData & i_cb_data)56 void StepBegin(const session::TrainLoopCallBackData &i_cb_data) override { 57 call_back_->StepBegin(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 58 }; 59 StepEnd(const session::TrainLoopCallBackData & i_cb_data)60 void StepEnd(const session::TrainLoopCallBackData &i_cb_data) override { 61 call_back_->StepEnd(TrainCallBackData(i_cb_data.train_mode_, i_cb_data.epoch_, i_cb_data.step_, model_)); 62 }; 63 64 private: 65 Model *model_; 66 TrainCallBack *call_back_; 67 }; 68 } // namespace mindspore 69 70 #endif // MINDSPORE_LITE_SRC_CXX_API_CALLBACK_CALLBACK_ADAPTER_H_ 71