1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cuh"
18
// Strict greater-than comparison for generic arithmetic types.
// Returns true iff lhs > rhs.
template <typename T>
__device__ __forceinline__ bool CompareFunc(T lhs, T rhs) {
  return lhs > rhs;
}
23
// half specialization: widen both operands to float and compare there.
template <>
__device__ __forceinline__ bool CompareFunc(half lhs, half rhs) {
  const float lhs_f = __half2float(lhs);
  const float rhs_f = __half2float(rhs);
  return lhs_f > rhs_f;
}
28
// Reciprocal square root: 1 / sqrt(x).
// NOTE(review): __frsqrt_rn takes and returns float (round-to-nearest), so
// this generic template is only exact for T = float; the half case is
// specialized below. Other instantiations would silently round through
// float — confirm only float/half are ever used (see instantiations at the
// bottom of this file).
template <typename T>
__device__ __forceinline__ T RsqrtFunc(T x) {
  return __frsqrt_rn(x);
}
33
// half specialization: use the native half-precision reciprocal-sqrt
// intrinsic instead of rounding through float.
template <>
__device__ __forceinline__ half RsqrtFunc(half x) {
  return hrsqrt(x);
}
38
// Absolute value for generic arithmetic types; resolves to the CUDA
// device-side abs overload matching T.
template <typename T>
__device__ __forceinline__ T AbsFunc(T value) {
  return abs(value);
}
43
// half specialization: compute |x| in float precision and narrow back.
template <>
__device__ __forceinline__ half AbsFunc(half x) {
  const float widened = __half2float(x);
  return __float2half(fabsf(widened));
}
48
// Sign function: +1 for positive, -1 for negative, 0 for zero.
// The zero test is done first so that any value comparing unequal to zero
// (including floating NaN, which is also not > 0) falls through to the
// +1/-1 branch, exactly as in the original nested-ternary form.
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
  if (x == static_cast<T>(0)) {
    return static_cast<T>(0);
  }
  return static_cast<T>(x > static_cast<T>(0) ? 1 : -1);
}
53
// half specialization: convert to float once, then apply the same
// zero / positive / negative classification as the generic version.
template <>
__device__ __forceinline__ half Sgn(half x) {
  const float xf = __half2float(x);
  if (xf == 0.0f) {
    return __float2half(0.0f);
  }
  return __float2half(xf > 0.0f ? 1.0f : -1.0f);
}
58
// Sparse proximal-Adagrad update kernel.
//
// Processes `size` scalar elements in a grid-stride loop. The dense tensors
// (`variable`, `accumulation`) are addressed as rows of
// `inner_size = size / indices_size` elements, where the row for work item
// `pos` is selected indirectly through `indices[pos / inner_size]`.
// `learning_rate`, `l1_regularization` and `l2_regularization` are scalar
// device pointers (only element [0] is read). Results are written in place
// to `variable`/`accumulation` and mirrored to `variable_out`/
// `accumulation_out`.
//
// NOTE(review): if `indices` contains duplicate values, multiple threads
// read-modify-write the same `accumulation`/`variable` slots concurrently —
// a data race. Presumably the caller guarantees unique indices; verify.
template <typename T>
__global__ void SparseApplyProximalAdagradUpdate(const size_t size, const size_t indices_size, const T *learning_rate,
                                                 const T *l1_regularization, const T *l2_regularization,
                                                 const T *gradient, const int *indices, T *variable, T *accumulation,
                                                 T *variable_out, T *accumulation_out) {
  // Elements per indexed row.
  const int inner_size = static_cast<int>(size / indices_size);
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int>(size); pos += gridDim.x * blockDim.x) {
    const int index = pos / inner_size;      // which entry of `indices` this element belongs to
    const int inner_pos = pos % inner_size;  // offset within that row
    const int grad_pos = pos;                // gradient is laid out densely in work-item order
    const int cur_pos = indices[index] * inner_size + inner_pos;  // destination slot in variable/accumulation
    // Adagrad accumulator update: accum += grad^2.
    accumulation[cur_pos] += gradient[grad_pos] * gradient[grad_pos];
    // Effective step size: lr / sqrt(accum).
    const T scratch1 = learning_rate[0] * RsqrtFunc(accumulation[cur_pos]);
    // Plain gradient step before the proximal shrinkage.
    T prox_v = variable[cur_pos];
    prox_v -= gradient[grad_pos] * scratch1;
    // Soft-thresholding: shrink |prox_v| by step * l1, clamped at zero.
    const T scratch2 = AbsFunc(prox_v) - scratch1 * l1_regularization[0];
    const T scratch3 = CompareFunc(scratch2, static_cast<T>(0.0)) ? scratch2 : static_cast<T>(0.0);
    // Apply the L1 shrinkage only when l1 > 0, then the L2 proximal divisor.
    variable[cur_pos] = CompareFunc(l1_regularization[0], static_cast<T>(0.0)) ? Sgn(prox_v) * scratch3 : prox_v;
    variable[cur_pos] = variable[cur_pos] / (static_cast<T>(1.0) + l2_regularization[0] * scratch1);
    // Mirror in-place results to the explicit output buffers.
    accumulation_out[cur_pos] = accumulation[cur_pos];
    variable_out[cur_pos] = variable[cur_pos];
  }
}
82
// Host-side launcher for SparseApplyProximalAdagradUpdate.
// Computes the launch configuration from the total element count `size`
// using the project-wide GET_BLOCKS/GET_THREADS helpers and enqueues the
// kernel on `cuda_stream`. All pointer arguments are device pointers.
template <typename T>
void CalSparseApplyProximalAdagrad(const size_t size, const size_t indices_size, const T *learning_rate,
                                   const T *l1_regularization, const T *l2_regularization, const T *gradient,
                                   const int *indices, T *variable, T *accumulation, T *variable_out,
                                   T *accumulation_out, cudaStream_t cuda_stream) {
  const auto block_count = GET_BLOCKS(size);
  const auto thread_count = GET_THREADS;
  SparseApplyProximalAdagradUpdate<<<block_count, thread_count, 0, cuda_stream>>>(
    size, indices_size, learning_rate, l1_regularization, l2_regularization, gradient, indices, variable, accumulation,
    variable_out, accumulation_out);
}
92
// Explicit template instantiations for the element types this kernel
// supports (float and half), so the definitions are emitted in this
// translation unit for callers linking against the header declaration.
template void CalSparseApplyProximalAdagrad<float>(const size_t size, const size_t indices_size,
                                                   const float *learning_rate, const float *l1_regularization,
                                                   const float *l2_regularization, const float *gradient,
                                                   const int *indices, float *variable, float *accumulation,
                                                   float *variable_out, float *accumulation_out,
                                                   cudaStream_t cuda_stream);
template void CalSparseApplyProximalAdagrad<half>(const size_t size, const size_t indices_size,
                                                  const half *learning_rate, const half *l1_regularization,
                                                  const half *l2_regularization, const half *gradient,
                                                  const int *indices, half *variable, half *accumulation,
                                                  half *variable_out, half *accumulation_out, cudaStream_t cuda_stream);
104