/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <math.h>
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/fp32/adam_fp32.h"
19
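/*
 * AdamFp32: performs one Adam update in place on var[start, end):
 *   m   <- m + (1 - beta1) * (g - m)
 *   v   <- v + (1 - beta2) * (g * g - v)
 *   var <- var - lr * m / (sqrt(v) + epsilon)                              (use_nesterov == false)
 *   var <- var - lr * (beta1 * m + (1 - beta1) * g) / (sqrt(v) + epsilon)  (use_nesterov == true)
 * No bias correction appears here; presumably the caller folds any (1 - beta^t)
 * correction terms into lr before calling.
 */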
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
             size_t start, size_t end, bool use_nesterov) {
  size_t c1 = start;
#ifdef ENABLE_AVX
  // Number of elements covered by the 8-wide AVX loop; the scalar loop below handles the tail.
  size_t c8 = ((end - start) / C8NUM) * C8NUM;
  __m256 coeff1_r = _mm256_set1_ps(1 - beta1);
  __m256 coeff2_r = _mm256_set1_ps(1 - beta2);
  __m256 beta1_r = _mm256_set1_ps(beta1);
  __m256 lr_r = _mm256_set1_ps(lr);
  __m256 epsi_r = _mm256_set1_ps(epsilon);

  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;
  const float *grad_ptr = gradient + start;

  __m256 avx_r0, avx_r1;
  __m256 var_r, m_r, v_r, grad_r;

  for (; c1 < start + c8; c1 += C8NUM) {
    // m <- m + (1 - beta1) * (g - m)
    grad_r = _mm256_loadu_ps(grad_ptr);
    m_r = _mm256_loadu_ps(m_ptr);
    avx_r0 = _mm256_sub_ps(grad_r, m_r);
    avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r);
    m_r = _mm256_add_ps(m_r, avx_r1);
    _mm256_storeu_ps(m_ptr, m_r);

    // v <- v + (1 - beta2) * (g * g - v)
    v_r = _mm256_loadu_ps(v_ptr);
    avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r);
    v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r));
    _mm256_storeu_ps(v_ptr, v_r);

    if (use_nesterov) {
      // var <- var - lr * (beta1 * m + (1 - beta1) * g) / (sqrt(v) + epsilon)
      avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
      avx_r1 = _mm256_mul_ps(lr_r, avx_r0);
      avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      __m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0);

      var_r = _mm256_loadu_ps(var_ptr);
      var_r = _mm256_sub_ps(var_r, avx_r2);
      _mm256_storeu_ps(var_ptr, var_r);
    } else {
      // var <- var - lr * m / (sqrt(v) + epsilon)
      avx_r0 = _mm256_mul_ps(lr_r, m_r);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      __m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1);
      var_r = _mm256_loadu_ps(var_ptr);
      var_r = _mm256_sub_ps(var_r, avx_r2);
      _mm256_storeu_ps(var_ptr, var_r);
    }
    m_ptr += C8NUM;
    v_ptr += C8NUM;
    var_ptr += C8NUM;
    grad_ptr += C8NUM;
  }
#endif

  // Scalar tail: covers the whole range when AVX is disabled, and the last (end - start) % C8NUM elements otherwise.
  for (; c1 < end; c1++) {
    m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
    if (use_nesterov) {
      var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
    } else {
      var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
    }
  }
  return NNACL_OK;
}

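/*
 * AdamDeltaFp32: computes the same Adam moment updates as AdamFp32, but instead of
 * modifying var in place it writes the signed step into delta:
 *   m     <- beta1 * m + (1 - beta1) * g
 *   v     <- beta2 * v + (1 - beta2) * g * g
 *   delta <- -lr * m_hat / (sqrt(v) + epsilon)
 * where m_hat is m, or beta1 * m + (1 - beta1) * g when use_nesterov is true.
 */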
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                  const float *gradient, size_t start, size_t end, bool use_nesterov) {
  size_t c1 = start;
#ifdef ENABLE_AVX
  size_t c8 = ((end - start) / C8NUM) * C8NUM;
  __m256 coeff1_r = _mm256_set1_ps(1.0f - beta1);
  __m256 coeff2_r = _mm256_set1_ps(1.0f - beta2);
  __m256 beta1_r = _mm256_set1_ps(beta1);
  __m256 beta2_r = _mm256_set1_ps(beta2);
  __m256 lr_r = _mm256_set1_ps(-lr);
  __m256 epsi_r = _mm256_set1_ps(epsilon);

  float *m_ptr = m + start;
  float *v_ptr = v + start;
  float *delta_ptr = delta + start;
  const float *gradient_ptr = gradient + start;

  __m256 m_r, v_r, delta_r, grad_r;
  __m256 avx_r0, avx_r1;
  for (; c1 < start + c8; c1 += C8NUM) {
    // m <- beta1 * m + (1 - beta1) * g
    m_r = _mm256_loadu_ps(m_ptr);
    avx_r0 = _mm256_mul_ps(m_r, beta1_r);
    grad_r = _mm256_loadu_ps(gradient_ptr);
    m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r));
    _mm256_storeu_ps(m_ptr, m_r);

    // v <- beta2 * v + (1 - beta2) * g * g
    v_r = _mm256_loadu_ps(v_ptr);
    avx_r0 = _mm256_mul_ps(v_r, beta2_r);
    avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r);
    v_r = _mm256_add_ps(avx_r0, avx_r1);
    _mm256_storeu_ps(v_ptr, v_r);

    if (use_nesterov) {
      // delta <- -lr * (beta1 * m + (1 - beta1) * g) / (sqrt(v) + epsilon); lr_r already holds -lr
      avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
      avx_r0 = _mm256_mul_ps(lr_r, avx_r0);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      delta_r = _mm256_div_ps(avx_r0, avx_r1);
      _mm256_storeu_ps(delta_ptr, delta_r);
    } else {
      // delta <- -lr * m / (sqrt(v) + epsilon)
      avx_r0 = _mm256_mul_ps(lr_r, m_r);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      delta_r = _mm256_div_ps(avx_r0, avx_r1);
      _mm256_storeu_ps(delta_ptr, delta_r);
    }
    m_ptr += C8NUM;
    v_ptr += C8NUM;
    delta_ptr += C8NUM;
    gradient_ptr += C8NUM;
  }
#endif

  // Scalar tail: covers the whole range when AVX is disabled, and the leftover elements otherwise.
  for (; c1 < end; ++c1) {
    m[c1] *= beta1;
    m[c1] += (1 - beta1) * gradient[c1];
    v[c1] *= beta2;
    v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
    if (use_nesterov) {
      delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
    } else {
      delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon);
    }
  }
  return NNACL_OK;
}

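/*
 * AdamWeightDecayFp32: Adam step with decoupled (AdamW-style) weight decay, applied
 * in place to var[start, end):
 *   m   <- beta1 * m + (1 - beta1) * g
 *   v   <- beta2 * v + (1 - beta2) * g * g
 *   var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)
 */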
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                        const float *gradient, size_t start, size_t end) {
  size_t c1 = start;
  const float beta1_minus = 1 - beta1;
  const float beta2_minus = 1 - beta2;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(beta1_minus);
  __m512 beta2_minus_r = _mm512_set1_ps(beta2_minus);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const float *gradient_ptr = gradient + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    __m512 var_r = _mm512_loadu_ps(var_ptr);
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_loadu_ps(gradient_ptr);

    // m <- beta1 * m + (1 - beta1) * g,  v <- beta2 * v + (1 - beta2) * g * g
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    // var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);
    _mm512_storeu_ps(var_ptr, var_r);

    gradient_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  // Scalar tail: covers the whole range when AVX512 is disabled, and the leftover elements otherwise.
  for (; c1 < end; c1++) {
    m[c1] += (gradient[c1] - m[c1]) * beta1_minus;
    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus;
    var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]);
  }
  return NNACL_OK;
}

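/*
 * FusedCastAdamFp32: same update as AdamWeightDecayFp32, except the gradient is
 * stored as fp16 (passed as int16_t bit patterns) and widened to fp32 on load.
 * Only the AVX512 16-wide body exists; the function returns the first index it did
 * not process, so the caller is expected to finish the tail (and the whole range
 * when ENABLE_AVX512 is not defined).
 */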
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
  __m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const int16_t *gradient16_ptr = gradient16 + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    __m512 var_r = _mm512_loadu_ps(var_ptr);
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    // Load 16 fp16 gradients and widen them to fp32.
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));

    // m <- beta1 * m + (1 - beta1) * g,  v <- beta2 * v + (1 - beta2) * g * g
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    // var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(var_ptr, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);

    gradient16_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  // Return the first index not processed here; callers are expected to handle the remaining elements.
  return c1;
}

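/*
 * FusedCastAdamFp16: like FusedCastAdamFp32, but the variable itself is also kept in
 * fp16: var16 is widened to fp32, updated with the AdamW-style step, and narrowed
 * back to fp16 on store. The moments m and v stay in fp32. Returns the first index
 * not processed by the vector loop.
 */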
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                         float decay, const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
  __m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const int16_t *gradient16_ptr = gradient16 + start;
  int16_t *var16_ptr = var16 + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    // Widen the fp16 variable and gradient to fp32 for the update.
    __m512 var_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(var16_ptr)));
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));

    // m <- beta1 * m + (1 - beta1) * g,  v <- beta2 * v + (1 - beta2) * g * g
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    // var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);
    // Narrow the updated variable back to fp16 (rounding immediate 0 = round to nearest even).
    _mm256_storeu_si256((__m256i *)var16_ptr, _mm512_cvtps_ph(var_r, 0));

    gradient16_ptr += C16NUM;
    var16_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}