1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 18 #include <vector> 19 #include <string> 20 #include <tuple> 21 #include <climits> 22 #include <unordered_map> 23 #include "include/train/train_loop_callback.h" 24 #include "include/train/metrics.h" 25 #include "include/lite_session.h" 26 27 namespace mindspore { 28 class MSTensor; 29 30 namespace dataset { 31 class Dataset; 32 using MSTensorVec = std::vector<mindspore::MSTensor>; 33 } // namespace dataset 34 35 using LoadDataFunc = std::function<int(std::vector<tensor::MSTensor *> inputs, dataset::MSTensorVec *dataset_vec)>; 36 37 namespace session { 38 39 class TrainLoop { 40 public: 41 /// \brief Static method to create a TrainLoop object 42 /// 43 /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API 44 /// 45 /// \return Pointer of MindSpore Lite TrainLoop 46 static TrainLoop *CreateTrainLoop(session::LiteSession *train_session); 47 48 /// \brief Class destructor 49 virtual ~TrainLoop() = default; 50 51 /// \brief Resets the epoch counter 52 /// 53 /// \return 0 on success or -1 in case of error 54 virtual int Reset() = 0; // resets the epoch counter to 0. 55 56 /// \brief Accessor to the LiteSession 57 /// 58 /// \return pointer of the train_session 59 const virtual session::LiteSession *train_session() = 0; 60 61 /// \brief Initialize object with metrics 62 /// 63 /// \param[in] verctor of metrics 64 /// 65 /// \return 0 on success or -1 in case of error 66 virtual int Init(std::vector<mindspore::session::Metrics *> metrics) = 0; 67 68 /// \brief Accessor to TrainLoop metric objects 69 /// 70 /// \return vector of metrics 71 virtual std::vector<mindspore::session::Metrics *> GetMetrics() = 0; 72 73 /// \brief Accessor to the Session KernelCallbacks 74 /// 75 /// \param[in] before Define a call_back_function to be called before running each node. 76 /// \param[in] after Define a call_back_function called after running each node. 77 /// 78 /// \return 0 on success or -1 in case of error 79 virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0; 80 81 /// \brief Performs the training Loop 82 /// 83 /// \param[in] epoch The number of epochs to run 84 /// \param[in] dataset Pointer to MindData Dataset object 85 /// \param[in] cbs A vector of TrainLoopCallBack objects 86 /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model 87 /// 88 /// \return 0 on success or -1 in case of error 89 virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, 90 LoadDataFunc load_func = nullptr) = 0; 91 92 /// \brief Performs loop over all data in Eval Mode 93 /// 94 /// \param[in] dataset Pointer to MindData Dataset object 95 /// \param[in] cbs A vector of TrainLoopCallBack objects 96 /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model 97 /// \param[in] max_steps (with default = INT_MAX the method iterates all dataset) 98 /// 99 /// \return 0 on success or -1 in case of error 100 virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, 101 LoadDataFunc load_func = nullptr, int max_steps = INT_MAX) = 0; 102 }; 103 } // namespace session 104 } // namespace mindspore 105 #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 106