1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 17 #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 18 #include <vector> 19 #include <string> 20 #include <tuple> 21 #include <unordered_map> 22 #include "include/train/train_loop_callback.h" 23 #include "include/train/metrics.h" 24 #include "src/litert/lite_session.h" 25 26 namespace mindspore { 27 class MSTensor; 28 29 namespace dataset { 30 class Dataset; 31 using MSTensorVec = std::vector<mindspore::MSTensor>; 32 } // namespace dataset 33 34 using LoadDataFunc = std::function<int(std::vector<lite::Tensor *> inputs, dataset::MSTensorVec *dataset_vec)>; 35 36 namespace session { 37 38 class TrainLoop { 39 public: 40 /// \brief Static method to create a TrainLoop object 41 /// 42 /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API 43 /// 44 /// \return Pointer of MindSpore Lite TrainLoop 45 static lite::TrainLoop *CreateTrainLoop(lite::LiteSession *train_session); 46 47 /// \brief Class destructor 48 virtual ~TrainLoop() = default; 49 50 /// \brief Resets the epoch counter 51 /// 52 /// \return 0 on success or -1 in case of error 53 virtual int Reset() = 0; // resets the epoch counter to 0. 54 55 /// \brief Accessor to the LiteSession 56 /// 57 /// \return pointer of the train_session 58 const virtual lite::LiteSession *train_session() = 0; 59 60 /// \brief Initialize object with metrics 61 /// 62 /// \param[in] verctor of metrics 63 /// 64 /// \return 0 on success or -1 in case of error 65 virtual int Init(std::vector<mindspore::session::Metrics *> metrics) = 0; 66 67 /// \brief Accessor to TrainLoop metric objects 68 /// 69 /// \return vector of metrics 70 virtual std::vector<mindspore::session::Metrics *> GetMetrics() = 0; 71 72 /// \brief Accessor to the Session KernelCallbacks 73 /// 74 /// \param[in] before Define a call_back_function to be called before running each node. 75 /// \param[in] after Define a call_back_function called after running each node. 76 /// 77 /// \return 0 on success or -1 in case of error 78 virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0; 79 80 /// \brief Performs the training Loop 81 /// 82 /// \param[in] epochs The number of epochs to run 83 /// \param[in] dataset Pointer to MindData Dataset object 84 /// \param[in] cbs A vector of TrainLoopCallBack objects 85 /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model 86 /// 87 /// \return 0 on success or -1 in case of error 88 virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs, 89 LoadDataFunc load_func) = 0; 90 91 /// \brief Performs loop over all data in Eval Mode 92 /// 93 /// \param[in] dataset Pointer to MindData Dataset object 94 /// \param[in] cbs A vector of TrainLoopCallBack objects 95 /// \param[in] load_func a function that load (and can manipulate) data from Minddata Dataset array into model 96 /// \param[in] max_steps (with default = INT_MAX the method iterates all dataset) 97 /// 98 /// \return 0 on success or -1 in case of error 99 virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<lite::TrainLoopCallBack *> cbs, 100 LoadDataFunc load_func, int max_steps) = 0; 101 }; 102 } // namespace session 103 } // namespace mindspore 104 #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ 105