/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"

#include <vector>
#include <string>
#include <memory>
#include <cmath>  // std::sqrt in the generic fallback path

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
// Nine inputs (m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
// grad) and one output (delta).
constexpr size_t kAdamDeltaInputsNum = 9;
constexpr size_t kAdamDeltaOutputsNum = 1;
}  // namespace
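
// Computes the Adam delta for each element i:
//   m[i] = beta1 * m[i] + (1 - beta1) * g[i]
//   v[i] = beta2 * v[i] + (1 - beta2) * g[i]^2
//   delta[i] = -lr * m[i] / (sqrt(v[i]) + epsilon)
// With use_nesterov_ set, the numerator becomes beta1 * m[i] + (1 - beta1) * g[i].
// The bias correction is expected to be folded into lr by the caller (see Launch).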
template <typename T>
void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon,
                                         const T *gradient, size_t size) {
  std::function<void(size_t, size_t)> task;
  if (dtype_ == kNumberTypeFloat32) {
    // Fast path: delegate to the vectorized nnacl fp32 kernel.
    task = [this, delta, m, v, lr, beta1, beta2, epsilon, gradient](size_t start, size_t end) {
      (void)AdamDeltaFp32(delta, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
    };
  } else {
    // Generic fallback: scalar Adam update for each element in [start, end).
    task = [this, delta, m, v, lr, beta1, beta2, epsilon, gradient](size_t start, size_t end) {
      for (size_t c1 = start; c1 < end; ++c1) {
        m[c1] *= beta1;
        m[c1] += (1 - beta1) * gradient[c1];
        v[c1] *= beta2;
        v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
        if (use_nesterov_) {
          delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (std::sqrt(v[c1]) + epsilon);
        } else {
          delta[c1] = -lr * m[c1] / (std::sqrt(v[c1]) + epsilon);
        }
      }
    };
  }
  CPUKernelUtils::ParallelFor(task, size);
}
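
// Caches the element count and dtype from the kernel node and verifies that
// delta, m, v and grad all share the same shape. The optional use_nesterov
// attribute selects the Nesterov-style update in LaunchAdamDelta.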
void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  std::vector<size_t> delta_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  if (!IsSameShape(delta_shape, m_shape)) {
    MS_LOG(EXCEPTION) << "Delta and m should have the same shape";
  }
  if (!IsSameShape(delta_shape, v_shape)) {
    MS_LOG(EXCEPTION) << "Delta and v should have the same shape";
  }
  if (!IsSameShape(delta_shape, grad_shape)) {
    MS_LOG(EXCEPTION) << "Delta and grad should have the same shape";
  }
  if (delta_shape.empty()) {
    MS_LOG(EXCEPTION) << "Delta must be at least 1-D";
  }
  elem_num_ = 1;
  for (size_t i = 0; i < delta_shape.size(); ++i) {
    elem_num_ *= delta_shape[i];
  }
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "Invalid delta shape: the element count must be positive";
  }
  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, USE_NESTEROV);
  }
}
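
// Sanity-checks the launch address lists: m, v and grad must each hold
// elem_num_ fp32 values, the six hyper-parameters must each be one fp32
// scalar, and the single output (delta) must match the element count.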
void AdamDeltaCPUKernel::CheckParams(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &outputs) const {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdamDeltaInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdamDeltaOutputsNum, kernel_name_);

  // All buffers are fp32, so tensor inputs take elem_num_ * 4 bytes and each
  // scalar hyper-parameter takes 4 bytes.
  size_t elem_size = elem_num_ * 4;
  std::vector<size_t> expect_sizes = {elem_size, elem_size, 4, 4, 4, 4, 4, 4, elem_size};
  std::vector<std::string> input_names = {"m",     "v",     "beta1_power", "beta2_power", "lr",
                                          "beta1", "beta2", "epsilon",     "grad"};
  for (size_t i = 0; i < kAdamDeltaInputsNum; ++i) {
    if (inputs[i]->size != expect_sizes[i]) {
      MS_LOG(EXCEPTION) << "Unexpected size of input " << input_names[i];
    }
  }
  if (outputs.size() < 1 || outputs[0]->size != elem_size) {
    MS_LOG(EXCEPTION) << "Unexpected size of output delta";
  }
}
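
// Unpacks the inputs and folds the Adam bias correction into the learning
// rate before the element-wise update:
//   lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
// where beta1_power and beta2_power are beta1^t and beta2^t, precomputed by
// the caller.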
bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  CheckParams(inputs, outputs);
  auto m = reinterpret_cast<float *>(inputs[0]->addr);
  auto v = reinterpret_cast<float *>(inputs[1]->addr);
  auto beta1_power = reinterpret_cast<float *>(inputs[2]->addr)[0];
  if (beta1_power == 1) {
    MS_LOG(EXCEPTION) << "The beta1_power should not be 1";
  }
  auto beta2_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
  auto lr = reinterpret_cast<float *>(inputs[4]->addr)[0];
  auto beta1 = reinterpret_cast<float *>(inputs[5]->addr)[0];
  auto beta2 = reinterpret_cast<float *>(inputs[6]->addr)[0];
  auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
  auto grad = reinterpret_cast<float *>(inputs[8]->addr);
  auto delta = reinterpret_cast<float *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(m);
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(grad);
  MS_EXCEPTION_IF_NULL(delta);

  // Bias-corrected learning rate; beta1_power == 1 was rejected above.
  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
  // Split the element-wise update across threads.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
  return true;
}
}  // namespace kernel
}  // namespace mindspore