/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#include <vector>
#include <cmath>
#include <cfloat>
#include <algorithm>
#include <string>
#include <atomic>
#include <iostream>
#include "src/executor/kernel_exec.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;

namespace mindspore::kernel {
constexpr static int kWeightIdx = 0;
constexpr static int kMomentVector1stIdx = 1;
constexpr static int kMomentVector2stIdx = 2;

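// NORMAL applies gradients as they arrive; VIRTUAL_BATCH and ACCUMULATE_GRADS
// accumulate gradients into grad_sum_ and defer the weight update to OptimizerStep().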
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };

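// Base class for training optimizer kernels. It tracks the indices of the
// learning-rate and gradient input tensors and, in the accumulation modes,
// owns the grad_sum_ buffer that ExecuteVirtualBatch() adds gradients into.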
class OptimizerKernel : public LiteKernel {
 public:
  OptimizerKernel() = default;
  OptimizerKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, int lr_idx, int grad_idx)
      : LiteKernel(parameter, inputs, outputs, ctx), lr_idx_(lr_idx), grad_idx_(grad_idx) {}
  ~OptimizerKernel() = default;

  WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; }

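  // Caches the initial learning rate read from the learning-rate input tensor
  // so that RestoreDefaultLearningRate() can reinstate it later.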
  int Prepare() override {
    default_lr_ = reinterpret_cast<float *>(in_tensors_.at(static_cast<size_t>(lr_idx_))->MutableData())[0];
    lr_ = default_lr_;
    return RET_OK;
  }

  int SetLearningRate(float lr) {
    lr_ = lr;
    return RET_OK;
  }

  float GetLearningRate() { return lr_; }

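  // Overridden by concrete optimizers to report which input tensors hold
  // optimizer state and trainable parameters; the base class reports none.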
  virtual std::vector<int> GetOptimizerParamsIdxs() const {
    std::vector<int> indices;
    return indices;
  }

  virtual std::vector<int> GetTrainableParamsIdxs() const {
    std::vector<int> indices;
    return indices;
  }

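  // Returns the constant optimizer-parameter input tensors, always including
  // the learning-rate tensor.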
  std::vector<lite::Tensor *> GetOptimizerParams() const {
    std::vector<lite::Tensor *> params;
    auto indices = GetOptimizerParamsIdxs();
    indices.push_back(lr_idx_);
    for (size_t ix = 0; ix < indices.size(); ix++) {
      auto param = in_tensors_.at(indices[ix]);
      if (!param->IsConst()) {
        continue;
      }
      params.push_back(param);
    }
    return params;
  }

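  // Copies the value of a single-element float tensor into the optimizer
  // parameter with the matching name, keeping lr_ in sync when the
  // learning-rate tensor is the one updated.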
  bool SetOptimizerParams(lite::Tensor *param) {
    if (param == nullptr) {
      return false;
    }
    bool found = false;
    auto indices = GetOptimizerParamsIdxs();
    indices.push_back(lr_idx_);
    for (size_t ix = 0; ix < indices.size(); ix++) {
      if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name() && param->ElementsNum() == 1 &&
          (param->data_type() == kNumberTypeFloat32 || param->data_type() == kNumberTypeFloat)) {
        auto value = static_cast<float *>(param->MutableData())[0];
        static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
        if (lr_idx_ == indices[ix]) {
          lr_ = value;
        }
        found = true;
        break;
      }
    }
    return found;
  }

  std::vector<lite::Tensor *> GetTrainableParams() const {
    std::vector<lite::Tensor *> params;
    auto indices = GetTrainableParamsIdxs();
    for (size_t ix = 0; ix < indices.size(); ix++) {
      auto param = in_tensors_.at(indices[ix]);
      if (!param->IsConst()) {
        continue;
      }
      params.push_back(param);
    }
    return params;
  }

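  // Wraps the accumulated gradient buffer in a tensor that does not own the
  // underlying data, or returns nullptr if no accumulation buffer exists;
  // the caller is responsible for deleting the returned tensor object.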
  lite::Tensor *GetGradients() {
    lite::Tensor *grad_sum_tensor = nullptr;
    if (grad_sum_ != nullptr) {
      auto shape = in_tensors_.at(grad_idx_)->shape();
      grad_sum_tensor = new (std::nothrow) lite::Tensor(kNumberTypeFloat, shape);
      if (grad_sum_tensor == nullptr) {
        MS_LOG(ERROR) << "failed to allocate grad sum tensor";
        return nullptr;
      }
      grad_sum_tensor->set_tensor_name(in_tensors_.at(grad_idx_)->tensor_name());
      grad_sum_tensor->set_data(static_cast<void *>(grad_sum_));
      grad_sum_tensor->set_own_data(false);
    }
    return grad_sum_tensor;
  }

  int RestoreDefaultLearningRate() {
    auto ret = SetLearningRate(default_lr_);
    return ret;
  }

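  // Entering VIRTUAL_BATCH or ACCUMULATE_GRADS (re)allocates a zeroed
  // accumulation buffer; leaving those modes applies any pending optimizer
  // step and frees the buffer.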
  int SetOptimizerMode(WeightUpdateMode mod) {
    if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) {
      if (grad_sum_ != nullptr) {
        ms_context_->allocator->Free(grad_sum_);
        grad_sum_ = nullptr;
      }
      size_t size = in_tensors_.at(grad_idx_)->Size();
      size_t elem_num = in_tensors_.at(grad_idx_)->ElementsNum();
      grad_sum_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(size));
      if (grad_sum_ == nullptr) {
        MS_LOG(ERROR) << "failed to malloc grad sum tensor, size=" << size;
        return RET_ERROR;
      }
      valid_grad_sum_ = false;
      std::fill(grad_sum_, grad_sum_ + elem_num, 0);
      weight_update_mod_ = mod;
    } else {
      if (grad_sum_ != nullptr) {
        auto ret = OptimizerStep();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "OptimizerStep failed.";
          return RET_ERROR;
        }
        ms_context_->allocator->Free(grad_sum_);
        grad_sum_ = nullptr;
      }
    }
    return RET_OK;
  }

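  // Accumulates this step's gradients into grad_sum_; the element range is
  // split across threads by task_id.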
  int ExecuteVirtualBatch(int task_id) {
    auto gradient = reinterpret_cast<float *>(in_tensors_.at(grad_idx_)->MutableData());
    int length = in_tensors_.at(grad_idx_)->ElementsNum();

    int stride = UP_DIV(length, ms_context_->thread_num_);
    int count = MSMIN(stride, length - stride * task_id);
    int start = stride * task_id;
    int end = start + count;
    for (int i = start; i < end; ++i) {
      grad_sum_[i] += gradient[i];
    }
    valid_grad_sum_ = true;
    return RET_OK;
  }

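  // Overridden by concrete optimizers to apply the accumulated gradients to
  // the weights; the base implementation only clears the accumulation flag.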
  virtual int OptimizerStep() {
    valid_grad_sum_ = false;
    return RET_OK;
  }

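  // Applies any pending optimizer step before switching to eval mode, except
  // when gradients are being accumulated explicitly.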
  int Eval() override {
    if (weight_update_mod_ != WeightUpdateMode::ACCUMULATE_GRADS) {
      auto ret = OptimizerStep();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "OptimizerStep failed.";
        return RET_ERROR;
      }
    }
    return LiteKernel::Eval();
  }

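  // With CPU fp16 enabled, rejects NaN/Inf gradients and, if the gradient
  // tensor carries a scale, divides the scale out before the weight update.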
  int PreProcess() override {
    auto ret = LiteKernel::PreProcess();
    if (ret != RET_OK) {
      return ret;
    }

    auto ctx = static_cast<const lite::InnerContext *>(this->ms_context_);
    if (ctx->IsCpuFloat16Enabled()) {
      auto t = in_tensors_.at(grad_idx_);
      auto gradient = reinterpret_cast<float *>(t->data());
      int length = static_cast<int>(in_tensors_.at(static_cast<size_t>(grad_idx_))->ElementsNum());

      for (int i = 0; i < length; ++i) {
        if (std::isnan(gradient[i]) || std::isinf(gradient[i])) {
          MS_LOG(INFO) << "optimizer grad is nan or inf";
          return RET_OUT_OF_TENSOR_RANGE;
        }
      }

      auto is_scale = t->IsScale();
      auto scale = t->get_scale();
      if (is_scale) {
        t->set_scale(1.0f / scale);
        for (int i = 0; i < length; ++i) {
          gradient[i] *= (1.0f / scale);
        }
      }
    }
    return RET_OK;
  }
  int set_grad_sum_valid() {
    valid_grad_sum_ = true;
    return RET_OK;
  }

 protected:
  float default_lr_ = 0.0f;
  float lr_ = 0.0f;
  int lr_idx_ = 0;
  int grad_idx_ = 0;
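  // Gradient accumulation buffer for VIRTUAL_BATCH / ACCUMULATE_GRADS modes;
  // allocated from the context allocator in SetOptimizerMode().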
  float *grad_sum_ = nullptr;
  std::atomic_bool valid_grad_sum_ = false;

 private:
  WeightUpdateMode weight_update_mod_ = WeightUpdateMode::NORMAL;
};

}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_