/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"

#include <cmath>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kAdamDeltaInputsNum = 9;
constexpr size_t kAdamDeltaOutputsNum = 1;
}  // namespace

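// Computes the Adam update in parallel over `size` elements: the first and
// second moments m and v are updated in place, and the resulting parameter
// update is written to `delta` (this kernel only produces the delta; it does
// not apply it to the variable being optimized).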
template <typename T>
void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon,
                                         const T *gradient, size_t size) {
  std::function<void(size_t, size_t)> task;
  if (dtype_ == kNumberTypeFloat32) {
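    // Fast path: hand each chunk [start, end) to the vectorized nnacl fp32 kernel.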
    task = [this, delta, m, v, lr, beta1, beta2, epsilon, gradient](size_t start, size_t end) {
      (void)AdamDeltaFp32(delta, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
    };
  } else {
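    // Generic fallback: scalar Adam update per element,
    //   m <- beta1 * m + (1 - beta1) * g
    //   v <- beta2 * v + (1 - beta2) * g^2
    //   delta <- -lr * m / (sqrt(v) + epsilon)   (Nesterov variant below)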
    task = [this, delta, m, v, lr, beta1, beta2, epsilon, gradient](size_t start, size_t end) {
      for (size_t c1 = start; c1 < end; ++c1) {
        m[c1] *= beta1;
        m[c1] += (1 - beta1) * gradient[c1];
        v[c1] *= beta2;
        v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
        if (use_nesterov_) {
          delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (std::sqrt(v[c1]) + epsilon);
        } else {
          delta[c1] = -lr * m[c1] / (std::sqrt(v[c1]) + epsilon);
        }
      }
    };
  }
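  // Split the element range across the CPU worker threads.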
  CPUKernelUtils::ParallelFor(task, size);
}

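// Reads shapes and dtype from the kernel node and checks that delta, m, v
// and grad all share the same non-empty shape.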
void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  std::vector<size_t> delta_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  if (!IsSameShape(delta_shape, m_shape)) {
    MS_LOG(EXCEPTION) << "Delta and m should have the same shape";
  }
  if (!IsSameShape(delta_shape, v_shape)) {
    MS_LOG(EXCEPTION) << "Delta and v should have the same shape";
  }
  if (!IsSameShape(delta_shape, grad_shape)) {
    MS_LOG(EXCEPTION) << "Delta and grad should have the same shape";
  }
  if (delta_shape.empty()) {
    MS_LOG(EXCEPTION) << "Delta must be at least 1D";
  }
  elem_num_ = 1;
  for (size_t i = 0; i < delta_shape.size(); ++i) {
    elem_num_ *= delta_shape[i];
  }
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "Invalid delta shape";
  }
  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, USE_NESTEROV);
  }
}

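// Validates input/output counts and byte sizes: m, v, grad and the delta
// output must each hold elem_num_ floats; the remaining inputs are scalars.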
void AdamDeltaCPUKernel::CheckParams(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &outputs) const {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdamDeltaInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdamDeltaOutputsNum, kernel_name_);

  size_t elem_size = elem_num_ * sizeof(float);
  size_t scalar_size = sizeof(float);
  std::vector<size_t> expect_sizes = {elem_size,   elem_size,   scalar_size, scalar_size, scalar_size,
                                      scalar_size, scalar_size, scalar_size, elem_size};
  std::vector<std::string> input_names = {"m", "v", "beta1_power", "beta2_power", "lr",
                                          "beta1", "beta2", "epsilon", "grad"};
  for (size_t i = 0; i < kAdamDeltaInputsNum; ++i) {
    if (inputs[i]->size != expect_sizes[i]) {
      MS_LOG(EXCEPTION) << "Input " << input_names[i] << " has size " << inputs[i]->size << ", expected "
                        << expect_sizes[i];
    }
  }
  if (outputs.size() < 1 || outputs[0]->size != elem_size) {
    MS_LOG(EXCEPTION) << "Output delta has size " << outputs[0]->size << ", expected " << elem_size;
  }
}

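// Unpacks the scalar hyper-parameters, folds Adam's bias correction into the
// learning rate, and dispatches the per-element update.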
bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  CheckParams(inputs, outputs);
  auto m = reinterpret_cast<float *>(inputs[0]->addr);
  auto v = reinterpret_cast<float *>(inputs[1]->addr);
  auto beta1_power = reinterpret_cast<float *>(inputs[2]->addr)[0];
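  // A beta1_power of exactly 1 would make the bias-corrected lr below divide by zero.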
  if (beta1_power == 1) {
    MS_LOG(EXCEPTION) << "beta1_power should not be 1";
  }
  auto beta2_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
  auto lr = reinterpret_cast<float *>(inputs[4]->addr)[0];
  auto beta1 = reinterpret_cast<float *>(inputs[5]->addr)[0];
  auto beta2 = reinterpret_cast<float *>(inputs[6]->addr)[0];
  auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
  auto grad = reinterpret_cast<float *>(inputs[8]->addr);
  auto delta = reinterpret_cast<float *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(m);
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(grad);
  MS_EXCEPTION_IF_NULL(delta);

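  // Fold the bias correction into the learning rate:
  //   lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)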
  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
  // Number of elements to split across threads.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
  return true;
}
}  // namespace kernel
}  // namespace mindspore