/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <math.h>
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/fp32/adam_fp32.h"

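/* AdamFp32 applies a plain Adam step in place over the index range [start, end):
 *   m   <- m + (1 - beta1) * (g - m)
 *   v   <- v + (1 - beta2) * (g * g - v)
 *   var <- var - lr * m_hat / (sqrt(v) + epsilon)
 * where m_hat is m itself, or beta1 * m + (1 - beta1) * g when use_nesterov is
 * set. No bias correction appears here; presumably the caller folds the
 * sqrt(1 - beta2^t) / (1 - beta1^t) factor into lr before calling. The AVX
 * path processes eight floats per iteration; the scalar loop handles the tail.
 */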
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
             size_t start, size_t end, bool use_nesterov) {
  size_t c1 = start;
#ifdef ENABLE_AVX
  size_t c8 = ((end - start) / C8NUM) * C8NUM;
  __m256 coeff1_r = _mm256_set1_ps(1 - beta1);
  __m256 coeff2_r = _mm256_set1_ps(1 - beta2);
  __m256 beta1_r = _mm256_set1_ps(beta1);
  __m256 lr_r = _mm256_set1_ps(lr);
  __m256 epsi_r = _mm256_set1_ps(epsilon);

  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;
  const float *grad_ptr = gradient + start;

  __m256 avx_r0, avx_r1;
  __m256 var_r, m_r, v_r, grad_r;

  for (; c1 < start + c8; c1 += C8NUM) {
    // m += (1 - beta1) * (g - m)
    grad_r = _mm256_loadu_ps(grad_ptr);
    m_r = _mm256_loadu_ps(m_ptr);
    avx_r0 = _mm256_sub_ps(grad_r, m_r);
    avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r);
    m_r = _mm256_add_ps(m_r, avx_r1);
    _mm256_storeu_ps(m_ptr, m_r);

    // v += (1 - beta2) * (g * g - v)
    v_r = _mm256_loadu_ps(v_ptr);
    avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r);
    v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r));
    _mm256_storeu_ps(v_ptr, v_r);

    if (use_nesterov) {
      // var -= lr * (beta1 * m + (1 - beta1) * g) / (sqrt(v) + epsilon)
      avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
      avx_r1 = _mm256_mul_ps(lr_r, avx_r0);
      avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      __m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0);

      var_r = _mm256_loadu_ps(var_ptr);
      var_r = _mm256_sub_ps(var_r, avx_r2);
      _mm256_storeu_ps(var_ptr, var_r);
    } else {
      // var -= lr * m / (sqrt(v) + epsilon)
      avx_r0 = _mm256_mul_ps(lr_r, m_r);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      __m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1);
      var_r = _mm256_loadu_ps(var_ptr);
      var_r = _mm256_sub_ps(var_r, avx_r2);
      _mm256_storeu_ps(var_ptr, var_r);
    }
    m_ptr += C8NUM;
    v_ptr += C8NUM;
    var_ptr += C8NUM;
    grad_ptr += C8NUM;
  }
#endif

  // Scalar loop for the remaining elements.
  for (; c1 < end; c1++) {
    m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
    if (use_nesterov) {
      var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
    } else {
      var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
    }
  }
  return NNACL_OK;
}

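/* AdamDeltaFp32 computes the same Adam moment updates as AdamFp32 but, instead
 * of applying the step to the variable, writes the signed update into delta:
 *   delta <- -lr * m_hat / (sqrt(v) + epsilon)
 * The moments here use the multiply-then-add form
 * (m <- beta1 * m + (1 - beta1) * g), which is algebraically equivalent to the
 * incremental form in AdamFp32. Note lr_r is preloaded with -lr.
 */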
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                  const float *gradient, size_t start, size_t end, bool use_nesterov) {
  size_t c1 = start;
#ifdef ENABLE_AVX
  size_t c8 = ((end - start) / C8NUM) * C8NUM;
  __m256 coeff1_r = _mm256_set1_ps(1.0f - beta1);
  __m256 coeff2_r = _mm256_set1_ps(1.0f - beta2);
  __m256 beta1_r = _mm256_set1_ps(beta1);
  __m256 beta2_r = _mm256_set1_ps(beta2);
  __m256 lr_r = _mm256_set1_ps(-lr);
  __m256 epsi_r = _mm256_set1_ps(epsilon);

  float *m_ptr = m + start;
  float *v_ptr = v + start;
  float *delta_ptr = delta + start;
  const float *gradient_ptr = gradient + start;

  __m256 m_r, v_r, delta_r, grad_r;
  __m256 avx_r0, avx_r1;
  for (; c1 < start + c8; c1 += C8NUM) {
    // m = beta1 * m + (1 - beta1) * g
    m_r = _mm256_loadu_ps(m_ptr);
    avx_r0 = _mm256_mul_ps(m_r, beta1_r);
    grad_r = _mm256_loadu_ps(gradient_ptr);
    m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r));
    _mm256_storeu_ps(m_ptr, m_r);

    // v = beta2 * v + (1 - beta2) * g * g
    v_r = _mm256_loadu_ps(v_ptr);
    avx_r0 = _mm256_mul_ps(v_r, beta2_r);
    avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r);
    v_r = _mm256_add_ps(avx_r0, avx_r1);
    _mm256_storeu_ps(v_ptr, v_r);

    if (use_nesterov) {
      // delta = -lr * (beta1 * m + (1 - beta1) * g) / (sqrt(v) + epsilon)
      avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
      avx_r0 = _mm256_mul_ps(lr_r, avx_r0);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      delta_r = _mm256_div_ps(avx_r0, avx_r1);
      _mm256_storeu_ps(delta_ptr, delta_r);
    } else {
      // delta = -lr * m / (sqrt(v) + epsilon)
      avx_r0 = _mm256_mul_ps(lr_r, m_r);
      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
      delta_r = _mm256_div_ps(avx_r0, avx_r1);
      _mm256_storeu_ps(delta_ptr, delta_r);
    }
    m_ptr += C8NUM;
    v_ptr += C8NUM;
    delta_ptr += C8NUM;
    gradient_ptr += C8NUM;
  }
#endif

  // Scalar loop for the remaining elements.
  for (; c1 < end; ++c1) {
    m[c1] *= beta1;
    m[c1] += (1 - beta1) * gradient[c1];
    v[c1] *= beta2;
    v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
    if (use_nesterov) {
      delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
    } else {
      delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon);
    }
  }
  return NNACL_OK;
}

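/* AdamWeightDecayFp32 implements the AdamW variant: the decoupled weight decay
 * term decay * var is added to the Adam update before scaling by -lr:
 *   var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)
 * The AVX-512 path processes sixteen floats per iteration and leans on fused
 * multiply-add (_mm512_fmadd_ps) for the moment and variable updates.
 */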
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                        const float *gradient, size_t start, size_t end) {
  size_t c1 = start;
  const float beta1_minus = 1 - beta1;
  const float beta2_minus = 1 - beta2;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(beta1_minus);
  __m512 beta2_minus_r = _mm512_set1_ps(beta2_minus);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const float *gradient_ptr = gradient + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    __m512 var_r = _mm512_loadu_ps(var_ptr);
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_loadu_ps(gradient_ptr);

    // m = beta1 * m + (1 - beta1) * g;  v = beta2 * v + (1 - beta2) * g * g
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    // var -= lr * (m / (sqrt(v) + epsilon) + decay * var)
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);
    _mm512_storeu_ps(var_ptr, var_r);

    gradient_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  // Scalar loop for the remaining elements.
  for (; c1 < end; c1++) {
    m[c1] += (gradient[c1] - m[c1]) * beta1_minus;
    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus;
    var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]);
  }
  return NNACL_OK;
}

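/* FusedCastAdamFp32 fuses an fp16 -> fp32 gradient cast into the AdamW update:
 * gradient16 holds IEEE half-precision bit patterns (hence int16_t), which
 * _mm512_cvtph_ps widens to fp32 before the same update as in
 * AdamWeightDecayFp32. Only full 16-element blocks are handled; the function
 * returns the first unprocessed index so the caller can finish the tail
 * (without AVX-512 it simply returns start).
 */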
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
  __m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const int16_t *gradient16_ptr = gradient16 + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    __m512 var_r = _mm512_loadu_ps(var_ptr);
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    // Widen sixteen fp16 gradient values to fp32.
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));

    // Same AdamW update as in AdamWeightDecayFp32.
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(var_ptr, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);

    gradient16_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}

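/* FusedCastAdamFp16 is the variant for fp16 variables: var16 is loaded and
 * widened with _mm512_cvtph_ps, updated in fp32, then narrowed back with
 * _mm512_cvtps_ph (rounding immediate 0, i.e. round to nearest even). The
 * moments m and v stay in fp32 to preserve precision. As above, it returns
 * the first unprocessed index for the caller's scalar tail.
 */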
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                         float decay, const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
  __m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

  const int16_t *gradient16_ptr = gradient16 + start;
  int16_t *var16_ptr = var16 + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;

  for (; c1 < c16; c1 += C16NUM) {
    // Widen the fp16 variable and gradient to fp32 for the update.
    __m512 var_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(var16_ptr)));
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));

    // Same AdamW update as in AdamWeightDecayFp32.
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);
    // Narrow the updated variable back to fp16.
    _mm256_storeu_si256((__m256i *)var16_ptr, _mm512_cvtps_ph(var_r, 0));

    gradient16_ptr += C16NUM;
    var16_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}
300