/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#include <vector>
#include <cmath>
#include <cfloat>
#include <algorithm>
#include <string>
#include <iostream>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;

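// Self-comparison NaN check: a NaN never compares equal to itself. The volatile
// copy keeps the compiler from folding the comparison away under fast-math flags.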
static __attribute__((always_inline)) inline bool MS_ISNAN(float var) {
  volatile float d = var;
  return d != d;
}

namespace mindspore::kernel {
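// Gradient handling modes: NORMAL applies gradients on every step, while
// VIRTUAL_BATCH and ACCUMULATE_GRADS sum incoming gradients into grad_sum_
// so the weight update can be deferred across several mini-batches.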
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };

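// Base class shared by the training optimizer kernels. lr_idx and grad_idx give
// the positions of the learning-rate and gradient tensors inside in_tensors_.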
class OptimizerKernel : public InnerKernel {
 public:
  OptimizerKernel() = default;
  OptimizerKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, int lr_idx, int grad_idx)
      : InnerKernel(parameter, inputs, outputs, ctx), lr_idx_(lr_idx), grad_idx_(grad_idx) {}
  ~OptimizerKernel() = default;

  WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; }

  int Init() override {
    default_lr_ = reinterpret_cast<float *>(in_tensors_.at(lr_idx_)->MutableData())[0];
    lr_ = default_lr_;
    return RET_OK;
  }

  int SetLearningRate(float lr) {
    lr_ = lr;
    return RET_OK;
  }

  float GetLearningRate() { return lr_; }

  virtual std::vector<int> GetOptimizerParamsIdxs() const {
    std::vector<int> indices;
    return indices;
  }

  std::vector<tensor::MSTensor *> GetOptimizerParams() const {
    std::vector<tensor::MSTensor *> params;
    auto indices = GetOptimizerParamsIdxs();
    indices.push_back(lr_idx_);
    for (size_t ix = 0; ix < indices.size(); ix++) {
      auto param = in_tensors_.at(indices[ix]);
      if (param->data() == nullptr) {
        continue;
      }
      params.push_back(param);
    }
    return params;
  }

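  // Copies a scalar float parameter (for example the learning rate) into the
  // matching kernel input tensor; returns false when no input matches by name.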
  bool SetOptimizerParams(tensor::MSTensor *param) {
    if (param == nullptr) {
      return false;
    }
    bool found = false;
    auto indices = GetOptimizerParamsIdxs();
    indices.push_back(lr_idx_);
    for (size_t ix = 0; ix < indices.size(); ix++) {
      if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name() && param->ElementsNum() == 1 &&
          (param->data_type() == kNumberTypeFloat32 || param->data_type() == kNumberTypeFloat)) {
        auto value = static_cast<float *>(param->MutableData())[0];
        static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
        if (lr_idx_ == indices[ix]) {
          lr_ = value;
        }
        found = true;
        break;
      }
    }
    return found;
  }

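  // Wraps the accumulated gradient buffer in a non-owning tensor so callers can
  // inspect it; returns nullptr when no accumulation buffer has been allocated.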
  lite::Tensor *GetGradients() {
    lite::Tensor *grad_sum_tensor = nullptr;
    if (grad_sum_ != nullptr) {
      auto shape = in_tensors_.at(grad_idx_)->shape();
      grad_sum_tensor = new (std::nothrow) lite::Tensor(kNumberTypeFloat, shape);
      if (grad_sum_tensor == nullptr) {
        MS_LOG(ERROR) << "failed to allocate grad sum tensor";
        return nullptr;
      }
      grad_sum_tensor->set_tensor_name(in_tensors_.at(grad_idx_)->tensor_name());
      grad_sum_tensor->set_data(static_cast<void *>(grad_sum_));
      grad_sum_tensor->set_own_data(false);
    }
    return grad_sum_tensor;
  }

  int RestoreDefaultLearningRate() {
    SetLearningRate(default_lr_);
    return RET_OK;
  }

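  // Switching into VIRTUAL_BATCH or ACCUMULATE_GRADS allocates and zeroes the
  // accumulation buffer; switching back flushes it through OptimizerStep() and
  // frees it.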
  int SetOptimizerMode(WeightUpdateMode mod) {
    if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) {
      if (grad_sum_ != nullptr) {
        ms_context_->allocator->Free(grad_sum_);
        grad_sum_ = nullptr;
      }
      size_t size = in_tensors_.at(grad_idx_)->Size();
      size_t elem_num = in_tensors_.at(grad_idx_)->ElementsNum();
      grad_sum_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(size));
      if (grad_sum_ == nullptr) {
        MS_LOG(ERROR) << "failed to malloc grad sum tensor, size=" << size;
        return RET_ERROR;
      }
      valid_grad_sum_ = false;
      std::fill(grad_sum_, grad_sum_ + elem_num, 0);
      weight_update_mod_ = mod;
    } else {
      if (grad_sum_ != nullptr) {
        OptimizerStep();
        ms_context_->allocator->Free(grad_sum_);
        grad_sum_ = nullptr;
      }
    }
    return RET_OK;
  }

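  // Per-thread accumulation: each task adds its slice of the incoming gradient
  // tensor into grad_sum_.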
  int ExecuteVirtualBatch(int task_id) {
    auto gradient = reinterpret_cast<float *>(in_tensors_.at(grad_idx_)->MutableData());
    int length = in_tensors_.at(grad_idx_)->ElementsNum();

    int stride = UP_DIV(length, ms_context_->thread_num_);
    int count = MSMIN(stride, length - stride * task_id);
    int start = stride * task_id;
    int end = start + count;
    for (int i = start; i < end; ++i) {
      grad_sum_[i] += gradient[i];
    }
    valid_grad_sum_ = true;
    return RET_OK;
  }

  virtual int OptimizerStep() {
    valid_grad_sum_ = false;
    return RET_OK;
  }

  int Eval() override {
    if (weight_update_mod_ != WeightUpdateMode::ACCUMULATE_GRADS) {
      OptimizerStep();
    }
    return InnerKernel::Eval();
  }

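  // When CPU float16 is enabled, rejects NaN/Inf gradients and divides the
  // tensor's recorded scale factor back out of the gradient values before the
  // optimizer consumes them.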
  int PreProcess() override {
    auto ret = InnerKernel::PreProcess();
    if (ret != RET_OK) {
      return ret;
    }

    auto ctx = static_cast<const lite::InnerContext *>(this->ms_context_);
    if (ctx->IsCpuFloat16Enabled()) {
      auto t = in_tensors_.at(grad_idx_);
      auto gradient = reinterpret_cast<float *>(t->data());
      int length = in_tensors_.at(grad_idx_)->ElementsNum();

      for (int i = 0; i < length; ++i) {
        if (MS_ISNAN(gradient[i]) || std::isinf(gradient[i])) {
          MS_LOG(INFO) << "optimizer grad is nan or inf";
          return RET_OUT_OF_TENSOR_RANGE;
        }
      }

      auto is_scale = t->IsScale();
      auto scale = t->get_scale();
      if (is_scale) {
        t->set_scale(1.0f / scale);
        for (int i = 0; i < length; ++i) {
          gradient[i] *= (1.0f / scale);
        }
      }
    }
    return RET_OK;
  }
  int set_grad_sum_valid() {
    valid_grad_sum_ = true;
    return RET_OK;
  }

 protected:
  float default_lr_ = 0.0f;
  float lr_ = 0.0f;
  int lr_idx_ = 0;
  int grad_idx_ = 0;
  float *grad_sum_ = nullptr;
  bool valid_grad_sum_ = false;

 private:
  WeightUpdateMode weight_update_mod_ = WeightUpdateMode::NORMAL;
};

}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_