1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifdef ENABLE_SSE
17 #ifdef SUPPORT_MSVC
18 #include <immintrin.h>
19 #else
20 #include <x86intrin.h>
21 #endif
22 #endif
23
24 #ifdef ENABLE_AVX
25 #include <immintrin.h>
26 #endif
27
28 #include <math.h>
29 #include "nnacl/fp32/rmsprop_fp32.h"
30
RMSPropUnuseCenterFp32(float * variable,float * mean_square,float * moment,float * gradients,float momentum,float learning_rate,float decay,float epsilon,size_t start,size_t end)31 int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum,
32 float learning_rate, float decay, float epsilon, size_t start, size_t end) {
33 size_t c1 = start;
34 #ifdef ENABLE_AVX
35 size_t c8 = ((end - start) / C8NUM) * C8NUM;
36 float *variable_ptr = variable + start;
37 float *mean_square_ptr = mean_square + start;
38 float *gradients_ptr = gradients + start;
39 float *moment_ptr = moment + start;
40
41 __m256 decay_r = _mm256_set1_ps(1.0 - decay);
42 __m256 momentum_r = _mm256_set1_ps(momentum);
43 __m256 lr_r = _mm256_set1_ps(learning_rate);
44 __m256 epsi_r = _mm256_set1_ps(epsilon);
45 __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2;
46 for (; c1 < start + c8; c1 += C8NUM) {
47 gradient_r = _mm256_loadu_ps(gradients_ptr);
48 mean_square_r = _mm256_loadu_ps(mean_square_ptr);
49 avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r);
50 avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
51 mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
52 _mm256_storeu_ps(mean_square_ptr, mean_square_r);
53
54 avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r);
55 avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1);
56
57 moment_r = _mm256_loadu_ps(moment_ptr);
58 avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2);
59 _mm256_storeu_ps(moment_ptr, avx_r1);
60
61 variable_r = _mm256_loadu_ps(variable_ptr);
62 variable_r = _mm256_sub_ps(variable_r, avx_r1);
63 _mm256_storeu_ps(variable_ptr, variable_r);
64
65 gradients_ptr += C8NUM;
66 mean_square_ptr += C8NUM;
67 moment_ptr += C8NUM;
68 variable_ptr += C8NUM;
69 }
70 #endif
71
72 for (; c1 < end; c1++) {
73 mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay);
74 moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(mean_square[c1] + epsilon);
75 variable[c1] -= moment[c1];
76 }
77 return NNACL_OK;
78 }
79
RMSPropUseCenterFp32(float * variable,float * mean_square,float * moment,float * gradients,float * mean_gradients,float momentum,float learning_rate,float decay,float epsilon,size_t start,size_t end)80 int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients,
81 float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
82 size_t c1 = start;
83 #ifdef ENABLE_AVX
84 size_t c8 = ((end - start) / C8NUM) * C8NUM;
85 float *variable_ptr = variable + start;
86 float *mean_gradients_ptr = mean_gradients + start;
87 float *mean_square_ptr = mean_square + start;
88 float *moment_ptr = moment + start;
89 float *gradients_ptr = gradients + start;
90
91 __m256 decay_r = _mm256_set1_ps(1.0 - decay);
92 __m256 momentum_r = _mm256_set1_ps(momentum);
93 __m256 lr_r = _mm256_set1_ps(learning_rate);
94 __m256 epsi_r = _mm256_set1_ps(epsilon);
95 __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r;
96 __m256 avx_r1, avx_r2;
97 for (; c1 < start + c8; c1 += C8NUM) {
98 grad_r = _mm256_loadu_ps(gradients_ptr);
99 mean_square_r = _mm256_loadu_ps(mean_square_ptr);
100 avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r);
101 avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
102 mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
103 _mm256_storeu_ps(mean_square_ptr, mean_square_r);
104
105 mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr);
106 avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r);
107 mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1);
108 _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r);
109
110 avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r));
111 __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r);
112 __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS);
113 __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r);
114
115 avx_r1 = _mm256_mul_ps(grad_r, lr_r);
116 avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r));
117 moment_r = _mm256_loadu_ps(moment_ptr);
118 avx_r1 = _mm256_mul_ps(moment_r, momentum_r);
119 avx_r1 = _mm256_add_ps(avx_r1, avx_r2);
120 moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r);
121 _mm256_storeu_ps(moment_ptr, moment_r);
122
123 variable_r = _mm256_loadu_ps(variable_ptr);
124 avx_r1 = _mm256_sub_ps(variable_r, moment_r);
125 variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r);
126 _mm256_storeu_ps(variable_ptr, variable_r);
127
128 variable_ptr += C8NUM;
129 mean_gradients_ptr += C8NUM;
130 mean_square_ptr += C8NUM;
131 gradients_ptr += C8NUM;
132 moment_ptr += C8NUM;
133 }
134 #endif
135
136 for (; c1 < end; c1++) {
137 mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay);
138 mean_gradients[c1] += (gradients[c1] - mean_gradients[c1]) * (1.0 - decay);
139 float denom = (mean_square[c1] - mean_gradients[c1] * mean_gradients[c1]) + epsilon;
140 if (denom > 0) {
141 moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(denom);
142 variable[c1] -= moment[c1];
143 }
144 }
145 return NNACL_OK;
146 }
147