1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
18 
19 namespace mindspore::lite {
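// Transpose a 2-D float matrix on the device: out = in^T. params packs {m, n};
// implemented with cublasSgeam using alpha = 1 and beta = 0.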
void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle) {
  const int m = params[0];
  const int n = params[1];
  const float alpha = 1.0f;
  const float beta = 0.0f;
  CUBLAS_CHECK_VOID(
    cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, &alpha, in_addr, n, &beta, out_addr, m, out_addr, m));
}

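// Single GEMM: C = op(A) * op(B). params packs {m, n, k}, operations packs {trans_a, trans_b},
// and data_types packs {type_a, type_b, type_c, compute_type}. The operands are handed to the
// column-major cublasGemmEx call in swapped order (C^T = op(B)^T * op(A)^T), so row-major
// buffers can be used directly.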
void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                    const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle) {
  const int m = params[0];
  const int n = params[1];
  const int k = params[2];
  cublasOperation_t trans_a = operations[0];
  cublasOperation_t trans_b = operations[1];
  const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
  const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
  const int ldc = n;
  cudaDataType type_a = data_types[0];
  cudaDataType type_b = data_types[1];
  cudaDataType type_c = data_types[2];
  cudaDataType compute_type = data_types[3];
  const float alpha = 1.0f;
  const float beta = 0.0f;
  CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
                                 lda, &beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

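// Pointer-array batched GEMM: C[i] = op(A[i]) * op(B[i]) for i in [0, batch). params packs
// {m, n, k, batch}; the operands are swapped for the same row-major reason as in CublasMM1Batch.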
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
                     const cublasOperation_t *operations, const cudaDataType *data_types,
                     cublasHandle_t cublas_handle) {
  cublasOperation_t trans_a = operations[0];
  cublasOperation_t trans_b = operations[1];
  const int m = params[0];
  const int n = params[1];
  const int k = params[2];
  const int batch = params[3];
  const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
  const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
  const int ldc = n;
  cudaDataType type_a = data_types[0];
  cudaDataType type_b = data_types[1];
  cudaDataType type_c = data_types[2];
  cudaDataType compute_type = data_types[3];
  const float alpha = 1.0f;
  const float beta = 0.0f;
  CUBLAS_CHECK_VOID(cublasGemmBatchedEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addrs, type_b, ldb, a_addrs,
                                        type_a, lda, &beta, c_addrs, type_c, ldc, batch, compute_type,
                                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

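// Thin wrapper over cublasGemmEx in cuBLAS column-major convention. The caller supplies
// {m, n, k}, the leading dimensions, transpose flags, alpha/beta and the algorithm; the compute
// type is FP16 when A, B and C are all FP16, otherwise FP32 with TF32 tensor cores.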
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                       const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                       cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
  const int m = params[0];
  const int n = params[1];
  const int k = params[2];
  cublasOperation_t trans_a = operations[0];
  cublasOperation_t trans_b = operations[1];
  const int lda = lds[0];
  const int ldb = lds[1];
  const int ldc = lds[2];
  cudaDataType type_a = data_types[0];
  cudaDataType type_b = data_types[1];
  cudaDataType type_c = data_types[2];
  cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
    compute_type = CUBLAS_COMPUTE_16F;
  }
  CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
                                 ldb, beta, c_addr, type_c, ldc, compute_type, algo));
}

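// Strided-batched variant of CublasGemmWrapper: strides packs the element offsets between
// consecutive A, B and C matrices, and batch is the number of GEMMs.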
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                                     const int *lds, const cublasOperation_t *operations, const int *strides,
                                     const cudaDataType *data_types, void *alpha, void *beta, int batch,
                                     cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
  const int m = params[0];
  const int n = params[1];
  const int k = params[2];
  cublasOperation_t trans_a = operations[0];
  cublasOperation_t trans_b = operations[1];
  const int lda = lds[0];
  const int ldb = lds[1];
  const int ldc = lds[2];
  cudaDataType type_a = data_types[0];
  cudaDataType type_b = data_types[1];
  cudaDataType type_c = data_types[2];
  cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
    compute_type = CUBLAS_COMPUTE_16F;
  }
  const int stride_a = strides[0];
  const int stride_b = strides[1];
  const int stride_c = strides[2];

  CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
                                               stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
                                               stride_c, batch, compute_type, algo));
}

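// cuBLASLt matmul with an optional fused bias epilogue (applied when bias != nullptr).
// Matrix layouts are built from {m, n, k}, the transpose flags and the leading dimensions;
// the compute type follows the same FP16/TF32 selection as the wrappers above, and all
// descriptors are destroyed before returning.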
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                         const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                         const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) {
  cublasOperation_t trans_a = operations[0];
  cublasOperation_t trans_b = operations[1];
  cudaDataType type_a = data_types[0];
  cudaDataType type_b = data_types[1];
  cudaDataType type_c = data_types[2];
  const int m = params[0];
  const int n = params[1];
  const int k = params[2];
  const int lda = lds[0];
  const int ldb = lds[1];
  const int ldc = lds[2];

  cublasLtMatrixLayout_t mat_a_desc = NULL;
  cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m,
                             lda);
  cublasLtMatrixLayout_t mat_b_desc = NULL;
  cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k,
                             ldb);
  cublasLtMatrixLayout_t mat_c_desc = NULL;
  cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc);

  cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
  if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
    compute_type = CUBLAS_COMPUTE_16F;
  }

  cublasLtMatmulDesc_t mat_operation_desc = NULL;
  cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a);
  cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t));
  cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t));
  if (bias != nullptr) {
    cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
    cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
    cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *));
  }

  cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr,
                 mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream);
  cublasLtMatrixLayoutDestroy(mat_a_desc);
  cublasLtMatrixLayoutDestroy(mat_b_desc);
  cublasLtMatrixLayoutDestroy(mat_c_desc);
  cublasLtMatmulDescDestroy(mat_operation_desc);
}
}  // namespace mindspore::lite