• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifdef ENABLE_SSE
17 #ifdef SUPPORT_MSVC
18 #include <immintrin.h>
19 #else
20 #include <x86intrin.h>
21 #endif
22 #endif
23 
24 #ifdef ENABLE_AVX
25 #include <immintrin.h>
26 #endif
27 
28 #include <math.h>
29 #include "nnacl/fp32/rmsprop_fp32.h"
30 
RMSPropUnuseCenterFp32(float * variable,float * mean_square,float * moment,float * gradients,float momentum,float learning_rate,float decay,float epsilon,size_t start,size_t end)31 int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum,
32                            float learning_rate, float decay, float epsilon, size_t start, size_t end) {
33   size_t c1 = start;
34 #ifdef ENABLE_AVX
35   size_t c8 = ((end - start) / C8NUM) * C8NUM;
36   float *variable_ptr = variable + start;
37   float *mean_square_ptr = mean_square + start;
38   float *gradients_ptr = gradients + start;
39   float *moment_ptr = moment + start;
40 
41   __m256 decay_r = _mm256_set1_ps(1.0 - decay);
42   __m256 momentum_r = _mm256_set1_ps(momentum);
43   __m256 lr_r = _mm256_set1_ps(learning_rate);
44   __m256 epsi_r = _mm256_set1_ps(epsilon);
45   __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2;
46   for (; c1 < start + c8; c1 += C8NUM) {
47     gradient_r = _mm256_loadu_ps(gradients_ptr);
48     mean_square_r = _mm256_loadu_ps(mean_square_ptr);
49     avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r);
50     avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
51     mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
52     _mm256_storeu_ps(mean_square_ptr, mean_square_r);
53 
54     avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r);
55     avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1);
56 
57     moment_r = _mm256_loadu_ps(moment_ptr);
58     avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2);
59     _mm256_storeu_ps(moment_ptr, avx_r1);
60 
61     variable_r = _mm256_loadu_ps(variable_ptr);
62     variable_r = _mm256_sub_ps(variable_r, avx_r1);
63     _mm256_storeu_ps(variable_ptr, variable_r);
64 
65     gradients_ptr += C8NUM;
66     mean_square_ptr += C8NUM;
67     moment_ptr += C8NUM;
68     variable_ptr += C8NUM;
69   }
70 #endif
71 
72   for (; c1 < end; c1++) {
73     mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay);
74     moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(mean_square[c1] + epsilon);
75     variable[c1] -= moment[c1];
76   }
77   return NNACL_OK;
78 }
79 
RMSPropUseCenterFp32(float * variable,float * mean_square,float * moment,float * gradients,float * mean_gradients,float momentum,float learning_rate,float decay,float epsilon,size_t start,size_t end)80 int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients,
81                          float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
82   size_t c1 = start;
83 #ifdef ENABLE_AVX
84   size_t c8 = ((end - start) / C8NUM) * C8NUM;
85   float *variable_ptr = variable + start;
86   float *mean_gradients_ptr = mean_gradients + start;
87   float *mean_square_ptr = mean_square + start;
88   float *moment_ptr = moment + start;
89   float *gradients_ptr = gradients + start;
90 
91   __m256 decay_r = _mm256_set1_ps(1.0 - decay);
92   __m256 momentum_r = _mm256_set1_ps(momentum);
93   __m256 lr_r = _mm256_set1_ps(learning_rate);
94   __m256 epsi_r = _mm256_set1_ps(epsilon);
95   __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r;
96   __m256 avx_r1, avx_r2;
97   for (; c1 < start + c8; c1 += C8NUM) {
98     grad_r = _mm256_loadu_ps(gradients_ptr);
99     mean_square_r = _mm256_loadu_ps(mean_square_ptr);
100     avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r);
101     avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
102     mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
103     _mm256_storeu_ps(mean_square_ptr, mean_square_r);
104 
105     mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr);
106     avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r);
107     mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1);
108     _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r);
109 
110     avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r));
111     __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r);
112     __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS);
113     __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r);
114 
115     avx_r1 = _mm256_mul_ps(grad_r, lr_r);
116     avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r));
117     moment_r = _mm256_loadu_ps(moment_ptr);
118     avx_r1 = _mm256_mul_ps(moment_r, momentum_r);
119     avx_r1 = _mm256_add_ps(avx_r1, avx_r2);
120     moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r);
121     _mm256_storeu_ps(moment_ptr, moment_r);
122 
123     variable_r = _mm256_loadu_ps(variable_ptr);
124     avx_r1 = _mm256_sub_ps(variable_r, moment_r);
125     variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r);
126     _mm256_storeu_ps(variable_ptr, variable_r);
127 
128     variable_ptr += C8NUM;
129     mean_gradients_ptr += C8NUM;
130     mean_square_ptr += C8NUM;
131     gradients_ptr += C8NUM;
132     moment_ptr += C8NUM;
133   }
134 #endif
135 
136   for (; c1 < end; c1++) {
137     mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay);
138     mean_gradients[c1] += (gradients[c1] - mean_gradients[c1]) * (1.0 - decay);
139     float denom = (mean_square[c1] - mean_gradients[c1] * mean_gradients[c1]) + epsilon;
140     if (denom > 0) {
141       moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(denom);
142       variable[c1] -= moment[c1];
143     }
144   }
145   return NNACL_OK;
146 }
147