/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include <memory>
#include <map>
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/lite_session.h"

/*
   Inheritance Diagram

            +-------------------------------+
            |     session::LiteSession      |
            +--------------↑----------------+
                           |
            +--------------+----------------+
            |       lite::LiteSession       |
            +--------------↑----------------+
                           |
            +--------------+----------------+
            |       lite::TrainSession      |
            +-------------------------------+
*/

namespace mindspore {
namespace lite {
// A (KernelKey, KernelCreator) pair. Used by ReplaceOps()/RestoreOps() to save a
// registry entry so it can be put back after a training-specific substitution.
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;

// Training-capable session. Extends lite::LiteSession with a separate set of
// training kernels, train/eval mode switching, optimizer and gradient access,
// virtual batching, mixed-precision execution hooks and model export.
// Virtual inheritance keeps a single lite::LiteSession subobject in deeper
// derived sessions (see the inheritance diagram above).
class TrainSession : virtual public lite::LiteSession {
 public:
  TrainSession();
  ~TrainSession();

  // Runs the currently selected kernel set; before/after callbacks, when
  // non-null, are invoked around kernel execution.
  int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;

  // Compiles an inference graph (LiteSession contract).
  int CompileGraph(lite::Model *model) override;
  // Compiles the model into training kernels; the session shares ownership of
  // the model (stored in model_).
  virtual int CompileTrainGraph(std::shared_ptr<Model> model);

  // Initializes the session with an execution context and optional training
  // configuration (copied into cfg_ — TODO confirm against the .cc).
  virtual int Init(InnerContext *context, const TrainCfg *train_cfg);

  // Switch the session into training / evaluation mode respectively.
  int Train() override;
  int Eval() override;
  // Mode queries are pure reads of the train_mode_ flag.
  bool IsTrain() override { return train_mode_; }
  bool IsEval() override { return !train_mode_; }

  // Learning-rate / optimizer-state accessors (implemented in the .cc).
  int SetLearningRate(float learning_rate) override;
  float GetLearningRate() override;
  std::vector<tensor::MSTensor *> GetGradients() const override;
  std::vector<tensor::MSTensor *> GetOptimizerParams() const override;
  int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) override;
  int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) override;
  // Negative lr/momentum defaults presumably mean "leave unchanged" — confirm
  // against AdminSetupVirtualBatch in the implementation.
  int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;

  // The accessors below simply delegate to the base lite::LiteSession.
  void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }
  std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
  mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override {
    return lite::LiteSession::GetInputsByTensorName(tensor_name);
  }
  std::vector<tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override {
    return lite::LiteSession::GetOutputsByNodeName(node_name);
  }
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override {
    return lite::LiteSession::GetOutputs();
  }

  std::vector<std::string> GetOutputTensorNames() const override { return lite::LiteSession::GetOutputTensorNames(); }
  mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override {
    return lite::LiteSession::GetOutputByTensorName(tensor_name);
  }
  int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override;

  // Returns the eval-mode output tensors, i.e. the values of
  // eval_output_tensor_map_ (non-owning pointers; order follows map iteration).
  std::vector<tensor::MSTensor *> GetPredictions() const override {
    std::vector<tensor::MSTensor *> outputs;
    for (auto it = eval_output_tensor_map_.begin(); it != eval_output_tensor_map_.end(); ++it) {
      outputs.push_back(it->second);
    }
    return outputs;
  }
  // Exports the (possibly quantized) model to file fb_name; an empty
  // out_put_tensor_name list presumably means "export all outputs" — confirm.
  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
             std::vector<std::string> out_put_tensor_name = {}) override;

  std::vector<tensor::MSTensor *> GetFeatureMaps() const override;

  int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) override;
  // Collects, into *use_in_tensor_kernels, the inference kernels that consume
  // any of kernel_in_tensors (exact matching rule lives in the .cc).
  int FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
                            const std::vector<lite::Tensor *> &kernel_in_tensors,
                            const std::vector<kernel::LiteKernel *> &inference_kernels);
  // Selects, into *export_kernels, the kernels needed to produce the named
  // export output tensors.
  int FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
                        const std::vector<std::string> &export_output_tensor_names,
                        const std::vector<kernel::LiteKernel *> &inference_kernels);

 protected:
  // Allocates the shared workspace buffer (released by FreeWorkSpace()).
  int AllocWorkSpace();
  // Kernel classification helpers used while partitioning the graph.
  bool IsLossKernel(const kernel::LiteKernel *kernel) const;
  bool IsGradKernel(const kernel::LiteKernel *kernel) const;
  bool IsOptimizer(kernel::LiteKernel *kernel) const;
  bool IsMaskOutput(kernel::LiteKernel *kernel) const;
  bool IsBN(kernel::LiteKernel *kernel) const;

  // Swap in training-specific kernel creators; RestoreOps() reinstates the
  // saved originals.
  virtual std::vector<CreatorOp> ReplaceOps();
  virtual void RestoreOps(const std::vector<CreatorOp> &restore);
  // Graph-compilation phases (train kernels, inference subset, optimized
  // kernels, and the two output maps below).
  virtual void CompileTrainKernels();
  virtual int CompileInferenceKernels();
  virtual void CompileOptimizedKernels();
  virtual void CompileTrainOutputs();
  virtual void CompileEvalOutputs();
  virtual int InitCallBack();

  // Compiled model; shared with the caller of CompileTrainGraph().
  std::shared_ptr<Model> model_ = nullptr;
  // TrainCfg train_cfg_;
  // Output lookup tables captured at compile time: the original model outputs,
  // plus the eval-mode and train-mode views. Keys are node/tensor names;
  // tensor pointers are non-owning.
  std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
  std::vector<std::string> orig_output_tensor_names_;

  std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> eval_output_node_map_;
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> eval_output_tensor_map_;
  std::vector<std::string> eval_output_tensor_names_;

  std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> train_output_node_map_;
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> train_output_tensor_map_;
  std::vector<std::string> train_output_tensor_names_;

  // inference_kernels_ is the eval-mode subset; train_kernels_ is the full
  // training graph. Pointers are non-owning (kernels live in the session).
  std::vector<kernel::LiteKernel *> inference_kernels_;
  std::vector<kernel::LiteKernel *> train_kernels_;
  TrainCfg cfg_;

 private:
  // Name of the loss node, taken from the training configuration.
  std::string get_loss_name() const { return cfg_.loss_name_; }
  // Recursively gathers the kernels required to run inference from `ker`.
  void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels);
  int AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
  int OptimizerStep();
  // Executes run_kernel in order with the given callbacks; the MixPrecision
  // variant additionally applies the pre/post-process scaling below.
  int ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                  const std::vector<kernel::LiteKernel *> &run_kernel);
  int MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                              const std::vector<kernel::LiteKernel *> &run_kernel);
  int MixPrecisionPreProcess(kernel::LiteKernel *kernel, float scale);
  int MixPrecisionPostProcess(kernel::LiteKernel *kernel);
  bool IsLossTensor(Tensor *tensor);
  // Undo/cleanup for tensors saved in restored_origin_tensors_.
  void RestoreTensorData();
  void FreeRestoreTensors();
  bool AllInputsNeedScale(kernel::LiteKernel *kernel);
  void FreeWorkSpace();
  int AllocTensors(const std::vector<kernel::LiteKernel *> &kernels);

  // Mapping current tensor -> saved original, used by RestoreTensorData().
  std::map<Tensor *, Tensor *> restored_origin_tensors_;
  int virtual_batch_idx_ = 0;
  int virtual_batch_multiplier_ = 0;
  // Counter of consecutive non-NaN iterations (loss-scaling bookkeeping —
  // TODO confirm exact use in the .cc).
  uint32_t num_of_not_nan_iter_ = 0;
  void *workspace_ = nullptr;
  SchedCallBack sched_mix_precision_callback_;
  // false = eval mode, true = train mode; toggled by Train()/Eval().
  bool train_mode_ = false;
  void *tensors_data_ = nullptr;
  std::shared_ptr<Allocator> allocator_;
};

}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_