1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ 17 #define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ 18 #include <vector> 19 #include <cmath> 20 #include <cfloat> 21 #include <algorithm> 22 #include <string> 23 #include <atomic> 24 #include <iostream> 25 #include "src/executor/kernel_exec.h" 26 #include "include/errorcode.h" 27 using mindspore::lite::RET_ERROR; 28 using mindspore::lite::RET_OK; 29 using mindspore::lite::RET_OUT_OF_TENSOR_RANGE; 30 31 namespace mindspore::kernel { 32 constexpr static int kWeightIdx = 0; 33 constexpr static int kMomentVector1stIdx = 1; 34 constexpr static int kMomentVector2stIdx = 2; 35 36 enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS }; 37 38 class OptimizerKernel : public LiteKernel { 39 public: 40 OptimizerKernel() = default; OptimizerKernel(OpParameter * parameter,const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,const lite::InnerContext * ctx,int lr_idx,int grad_idx)41 OptimizerKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 42 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, int lr_idx, int grad_idx) 43 : LiteKernel(parameter, inputs, outputs, ctx), lr_idx_(lr_idx), grad_idx_(grad_idx) {} 44 ~OptimizerKernel() = default; 45 get_optimizer_mode()46 WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; } 47 Prepare()48 int Prepare() override { 49 default_lr_ = reinterpret_cast<float *>(in_tensors_.at(static_cast<size_t>(lr_idx_))->MutableData())[0]; 50 lr_ = default_lr_; 51 return RET_OK; 52 } 53 SetLearningRate(float lr)54 int SetLearningRate(float lr) { 55 lr_ = lr; 56 return RET_OK; 57 } 58 GetLearningRate()59 float GetLearningRate() { return lr_; } 60 GetOptimizerParamsIdxs()61 virtual std::vector<int> GetOptimizerParamsIdxs() const { 62 std::vector<int> indices; 63 return indices; 64 } 65 GetTrainableParamsIdxs()66 virtual std::vector<int> GetTrainableParamsIdxs() const { 67 std::vector<int> indices; 68 return indices; 69 } 70 GetOptimizerParams()71 std::vector<lite::Tensor *> GetOptimizerParams() const { 72 std::vector<lite::Tensor *> params; 73 auto indices = GetOptimizerParamsIdxs(); 74 indices.push_back(lr_idx_); 75 for (size_t ix = 0; ix < indices.size(); ix++) { 76 auto param = in_tensors_.at(indices[ix]); 77 if (!param->IsConst()) { 78 continue; 79 } 80 params.push_back(param); 81 } 82 return params; 83 } 84 SetOptimizerParams(lite::Tensor * param)85 bool SetOptimizerParams(lite::Tensor *param) { 86 if (param == nullptr) { 87 return false; 88 } 89 bool found = false; 90 auto indices = GetOptimizerParamsIdxs(); 91 indices.push_back(lr_idx_); 92 for (size_t ix = 0; ix < indices.size(); ix++) { 93 if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name() && param->ElementsNum() == 1 && 94 (param->data_type() == kNumberTypeFloat32 || param->data_type() == kNumberTypeFloat)) { 95 auto value = static_cast<float *>(param->MutableData())[0]; 96 static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value; 97 if (lr_idx_ == indices[ix]) { 98 lr_ = value; 99 } 100 found = true; 101 break; 102 } 103 } 104 return found; 105 } 106 GetTrainableParams()107 std::vector<lite::Tensor *> GetTrainableParams() const { 108 std::vector<lite::Tensor *> params; 109 auto indices = GetTrainableParamsIdxs(); 110 for (size_t ix = 0; ix < indices.size(); ix++) { 111 auto param = in_tensors_.at(indices[ix]); 112 if (!param->IsConst()) { 113 continue; 114 } 115 params.push_back(param); 116 } 117 return params; 118 } 119 GetGradients()120 lite::Tensor *GetGradients() { 121 lite::Tensor *grad_sum_tensor = nullptr; 122 if (grad_sum_ != nullptr) { 123 auto shape = in_tensors_.at(grad_idx_)->shape(); 124 grad_sum_tensor = new (std::nothrow) lite::Tensor(kNumberTypeFloat, shape); 125 if (grad_sum_tensor == nullptr) { 126 MS_LOG(ERROR) << "failed to allocate grad sum tensor"; 127 return nullptr; 128 } 129 grad_sum_tensor->set_tensor_name(in_tensors_.at(grad_idx_)->tensor_name()); 130 grad_sum_tensor->set_data(static_cast<void *>(grad_sum_)); 131 grad_sum_tensor->set_own_data(false); 132 } 133 return grad_sum_tensor; 134 } 135 RestoreDefaultLearningRate()136 int RestoreDefaultLearningRate() { 137 auto ret = SetLearningRate(default_lr_); 138 return ret; 139 } 140 SetOptimizerMode(WeightUpdateMode mod)141 int SetOptimizerMode(WeightUpdateMode mod) { 142 if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) { 143 if (grad_sum_ != nullptr) { 144 ms_context_->allocator->Free(grad_sum_); 145 grad_sum_ = nullptr; 146 } 147 size_t size = in_tensors_.at(grad_idx_)->Size(); 148 size_t elem_num = in_tensors_.at(grad_idx_)->ElementsNum(); 149 grad_sum_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(size)); 150 if (grad_sum_ == nullptr) { 151 MS_LOG(ERROR) << "failed to malloc grad sum tensor, size=" << size; 152 return RET_ERROR; 153 } 154 valid_grad_sum_ = false; 155 std::fill(grad_sum_, grad_sum_ + elem_num, 0); 156 weight_update_mod_ = mod; 157 } else { 158 if (grad_sum_ != nullptr) { 159 auto ret = OptimizerStep(); 160 if (ret != RET_OK) { 161 MS_LOG(ERROR) << "OptimizerStep failed."; 162 return RET_ERROR; 163 } 164 ms_context_->allocator->Free(grad_sum_); 165 grad_sum_ = nullptr; 166 } 167 } 168 return RET_OK; 169 } 170 ExecuteVirtualBatch(int task_id)171 int ExecuteVirtualBatch(int task_id) { 172 auto gradient = reinterpret_cast<float *>(in_tensors_.at(grad_idx_)->MutableData()); 173 int length = in_tensors_.at(grad_idx_)->ElementsNum(); 174 175 int stride = UP_DIV(length, ms_context_->thread_num_); 176 int count = MSMIN(stride, length - stride * task_id); 177 int start = stride * task_id; 178 int end = start + count; 179 for (int i = start; i < end; ++i) { 180 grad_sum_[i] += gradient[i]; 181 } 182 valid_grad_sum_ = true; 183 return RET_OK; 184 } 185 OptimizerStep()186 virtual int OptimizerStep() { 187 valid_grad_sum_ = false; 188 return RET_OK; 189 } 190 Eval()191 int Eval() override { 192 if (weight_update_mod_ != WeightUpdateMode::ACCUMULATE_GRADS) { 193 auto ret = OptimizerStep(); 194 if (ret != RET_OK) { 195 MS_LOG(ERROR) << "OptimizerStep failed."; 196 return RET_ERROR; 197 } 198 } 199 return LiteKernel::Eval(); 200 } 201 PreProcess()202 int PreProcess() override { 203 auto ret = LiteKernel::PreProcess(); 204 if (ret != RET_OK) { 205 return ret; 206 } 207 208 auto ctx = static_cast<const lite::InnerContext *>(this->ms_context_); 209 if (ctx->IsCpuFloat16Enabled()) { 210 auto t = in_tensors_.at(grad_idx_); 211 auto gradient = reinterpret_cast<float *>(t->data()); 212 int length = static_cast<int>(in_tensors_.at(static_cast<size_t>(grad_idx_))->ElementsNum()); 213 214 for (int i = 0; i < length; ++i) { 215 if (std::isnan(gradient[i]) || std::isinf(gradient[i])) { 216 MS_LOG(INFO) << "optimizer grad is nan or inf"; 217 return RET_OUT_OF_TENSOR_RANGE; 218 } 219 } 220 221 auto is_scale = t->IsScale(); 222 auto scale = t->get_scale(); 223 if (is_scale) { 224 t->set_scale(1.0f / scale); 225 for (int i = 0; i < length; ++i) { 226 gradient[i] *= (1.0f / scale); 227 } 228 } 229 } 230 return RET_OK; 231 } set_grad_sum_valid()232 int set_grad_sum_valid() { 233 valid_grad_sum_ = true; 234 return RET_OK; 235 } 236 237 protected: 238 float default_lr_ = 0.0f; 239 float lr_ = 0.0f; 240 int lr_idx_ = 0; 241 int grad_idx_ = 0; 242 float *grad_sum_ = nullptr; 243 std::atomic_bool valid_grad_sum_ = false; 244 245 private: 246 WeightUpdateMode weight_update_mod_ = WeightUpdateMode::NORMAL; 247 }; 248 249 } // namespace mindspore::kernel 250 #endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ 251