1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ 17 #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ 18 #include <vector> 19 #include <string> 20 #include <tuple> 21 #include <unordered_map> 22 #include <memory> 23 #include <map> 24 #include "include/train/train_cfg.h" 25 #include "src/litert/lite_session.h" 26 27 /* 28 Inheritance Diagram 29 30 +--------------+----------------+ 31 | lite::LiteSession | 32 +--------------↑----------------+ 33 | 34 +--------------+----------------+ 35 | lite::TrainSession | 36 +-------------------------------+ 37 */ 38 39 #define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \ 40 do { \ 41 if ((value)) { \ 42 MS_LOG(ERROR) << #msg; \ 43 return errcode; \ 44 } \ 45 } while (0) 46 47 namespace mindspore { 48 namespace lite { 49 using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>; 50 class TrainSession : virtual public lite::LiteSession { 51 public: 52 TrainSession(); 53 ~TrainSession(); 54 int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; 55 56 int CompileGraph(lite::Model *model) override; 57 virtual int CompileTrainGraph(std::shared_ptr<Model> model); 58 59 virtual int TrainInit(const std::shared_ptr<InnerContext> &context, const TrainCfg *train_cfg); 60 61 int Train() override; 62 int Eval() override; IsTrain()63 bool IsTrain() override { return train_mode_; } IsEval()64 bool IsEval() override { return !train_mode_; } 65 int SetLearningRate(float learning_rate) override; 66 float GetLearningRate() override; 67 std::vector<lite::Tensor *> GetGradients() const override; 68 std::vector<lite::Tensor *> GetOptimizerParams() const override; 69 int SetOptimizerParams(const std::vector<lite::Tensor *> ¶ms) override; 70 int ApplyGradients(const std::vector<lite::Tensor *> &gradients) override; 71 int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override; 72 BindThread(bool if_bind)73 void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); } GetInputs()74 std::vector<lite::Tensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); } GetInputsByTensorName(const std::string & tensor_name)75 mindspore::lite::Tensor *GetInputsByTensorName(const std::string &tensor_name) const override { 76 return lite::LiteSession::GetInputsByTensorName(tensor_name); 77 } GetOutputsByNodeName(const std::string & node_name)78 std::vector<lite::Tensor *> GetOutputsByNodeName(const std::string &node_name) const override { 79 return lite::LiteSession::GetOutputsByNodeName(node_name); 80 } GetOutputs()81 std::unordered_map<std::string, mindspore::lite::Tensor *> GetOutputs() const override { 82 return lite::LiteSession::GetOutputs(); 83 } 84 GetOutputTensorNames()85 std::vector<std::string> GetOutputTensorNames() const override { return lite::LiteSession::GetOutputTensorNames(); } GetOutputByTensorName(const std::string & tensor_name)86 mindspore::lite::Tensor *GetOutputByTensorName(const std::string &tensor_name) const override { 87 return lite::LiteSession::GetOutputByTensorName(tensor_name); 88 } 89 int Resize(const std::vector<lite::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) override; 90 GetPredictions()91 std::vector<lite::Tensor *> GetPredictions() const override { 92 std::vector<lite::Tensor *> outputs; 93 for (auto it = eval_output_tensor_map_.begin(); it != eval_output_tensor_map_.end(); ++it) { 94 outputs.push_back(it->second); 95 } 96 return outputs; 97 } 98 int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType, 99 std::vector<std::string> out_put_tensor_name = {}) override; 100 int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType, 101 std::vector<std::string> out_put_tensor_name = {}) override; 102 int ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type, FormatType, 103 bool enable_fp16, 104 const std::vector<std::string> &changeable_weights_name) override; 105 std::vector<lite::Tensor *> GetFeatureMaps() const override; 106 std::vector<lite::Tensor *> GetTrainableParams() const override; 107 108 int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) override; 109 int FindUseInTensorKernel(std::vector<kernel::KernelExec *> *use_in_tensor_kernels, 110 const std::vector<lite::Tensor *> &kernel_in_tensors, 111 const std::vector<kernel::KernelExec *> &inference_kernels); 112 int FindExportKernels(std::vector<kernel::KernelExec *> *export_kernels, 113 const std::vector<std::string> &export_output_tensor_names, 114 const std::vector<kernel::KernelExec *> &inference_kernels); 115 116 protected: 117 int AllocWorkSpace(); 118 bool IsLossKernel(const kernel::KernelExec *kernel) const; 119 bool IsLossInKernel(const kernel::KernelExec *kernel) const; 120 bool IsGradKernel(const kernel::KernelExec *kernel) const; 121 bool IsOptimizer(kernel::KernelExec *kernel) const; 122 bool IsMaskOutput(kernel::KernelExec *kernel) const; 123 bool IsBN(kernel::KernelExec *kernel) const; 124 125 virtual std::vector<CreatorOp> ReplaceOps(); 126 virtual void RestoreOps(const std::vector<CreatorOp> &restore); 127 virtual void CompileTrainKernels(); 128 virtual int CompileInferenceKernels(); 129 virtual void CompileOptimizedKernels(); 130 virtual void CompileTrainableParams(); 131 virtual int CompileConstFoldedKernels(); 132 virtual void CompileTrainOutputs(); 133 virtual void CompileEvalOutputs(); 134 virtual int InitCallBack(); 135 virtual int FindConstFoldedKernels(); 136 std::shared_ptr<Model> model_ = nullptr; 137 std::unordered_map<std::string, std::vector<mindspore::lite::Tensor *>> orig_output_node_map_; 138 std::unordered_map<std::string, mindspore::lite::Tensor *> orig_output_tensor_map_; 139 std::vector<std::string> orig_output_tensor_names_; 140 141 std::unordered_map<std::string, std::vector<mindspore::lite::Tensor *>> eval_output_node_map_; 142 std::unordered_map<std::string, mindspore::lite::Tensor *> eval_output_tensor_map_; 143 std::vector<std::string> eval_output_tensor_names_; 144 145 std::unordered_map<std::string, std::vector<mindspore::lite::Tensor *>> train_output_node_map_; 146 std::unordered_map<std::string, mindspore::lite::Tensor *> train_output_tensor_map_; 147 std::vector<std::string> train_output_tensor_names_; 148 149 std::vector<kernel::KernelExec *> inference_kernels_; 150 std::vector<kernel::KernelExec *> train_kernels_; 151 std::vector<kernel::KernelExec *> const_fold_kernels_; 152 std::vector<lite::Tensor *> const_output_tensors_; 153 TrainCfg cfg_; 154 155 private: get_loss_name()156 std::vector<std::string> get_loss_name() const { return cfg_.loss_name_; } 157 void BuildInferenceKernelsRecursive(kernel::KernelExec *ker, std::vector<kernel::KernelExec *> *req_kernels); 158 int AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); 159 int OptimizerStep(); 160 int ExecKernels(const KernelCallBack &before, const KernelCallBack &after, 161 const std::vector<kernel::KernelExec *> &run_kernel); 162 int MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after, 163 const std::vector<kernel::KernelExec *> &run_kernel); 164 int MixPrecisionPreProcess(kernel::KernelExec *kernel, float scale); 165 int MixPrecisionPostProcess(kernel::KernelExec *kernel); 166 bool IsLossTensor(Tensor *tensor); 167 void RestoreTensorData(); 168 void FreeRestoreTensors(); 169 bool AllInputsNeedScale(kernel::KernelExec *kernel); 170 void FreeWorkSpace(); 171 int AllocTensors(const std::vector<kernel::KernelExec *> &kernels); 172 bool IsInPlaceKernel(kernel::KernelExec *kernel); 173 bool IsInPlaceTensor(kernel::KernelExec *kernel, uint32_t idx, 174 const std::unordered_map<lite::Tensor *, int> &ref_count, uint32_t *input_idx); 175 size_t GetInplaceTensorOffset(kernel::KernelExec *kernel, 176 const std::unordered_map<lite::Tensor *, size_t> &offset_map, 177 std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx); 178 template <typename DestType> 179 int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType, 180 std::vector<std::string> out_put_tensor_name = {}); 181 lite::Tensor *FindObfTensor(); 182 int ChangeObfWeight(std::string tensor_name, float obf_ratio); 183 float ModelRecoverObfuscate(bool change_weight = true); 184 int ModelDeObfuscate(float obf_ratio); 185 std::map<Tensor *, Tensor *> restored_origin_tensors_; 186 std::vector<Tensor *> trainable_parameters_; 187 int virtual_batch_idx_ = 0; 188 int virtual_batch_multiplier_ = 0; 189 uint32_t num_of_not_nan_iter_ = 0; 190 void *workspace_ = nullptr; 191 SchedCallBack sched_mix_precision_callback_; 192 bool train_mode_ = false; 193 bool model_buff_changed_ = false; 194 void *tensors_data_ = nullptr; 195 size_t tensors_data_size_ = 0; 196 std::shared_ptr<Allocator> allocator_; 197 }; 198 199 } // namespace lite 200 } // namespace mindspore 201 #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ 202