/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17 #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
18
19 namespace mindspore::lite {
Cublas2DTranspose(const float * in_addr,float * out_addr,const int * params,cublasHandle_t cublas_handle)20 void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle) {
21 const int m = params[0];
22 const int n = params[1];
23 const float alpha = 1.0f;
24 const float beta = 0.0f;
25 CUBLAS_CHECK_VOID(
26 cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, &alpha, in_addr, n, &beta, out_addr, m, out_addr, m));
27 }
28
CublasMM1Batch(const void * a_addr,const void * b_addr,void * c_addr,const int * params,const cublasOperation_t * operations,const cudaDataType * data_types,cublasHandle_t cublas_handle)29 void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
30 const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle) {
31 const int m = params[0];
32 const int n = params[1];
33 const int k = params[2];
34 cublasOperation_t trans_a = operations[0];
35 cublasOperation_t trans_b = operations[1];
36 const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
37 const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
38 const int ldc = n;
39 cudaDataType type_a = data_types[0];
40 cudaDataType type_b = data_types[1];
41 cudaDataType type_c = data_types[2];
42 cudaDataType compute_type = data_types[3];
43 const float alpha = 1.0f;
44 const float beta = 0.0f;
45 CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
46 lda, &beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
47 }
48
CublasMMBatched(void ** a_addrs,void ** b_addrs,void ** c_addrs,const int * params,const cublasOperation_t * operations,const cudaDataType * data_types,cublasHandle_t cublas_handle)49 void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
50 const cublasOperation_t *operations, const cudaDataType *data_types,
51 cublasHandle_t cublas_handle) {
52 cublasOperation_t trans_a = operations[0];
53 cublasOperation_t trans_b = operations[1];
54 const int m = params[0];
55 const int n = params[1];
56 const int k = params[2];
57 const int batch = params[3];
58 const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
59 const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
60 const int ldc = n;
61 cudaDataType type_a = data_types[0];
62 cudaDataType type_b = data_types[1];
63 cudaDataType type_c = data_types[2];
64 cudaDataType compute_type = data_types[3];
65 const float alpha = 1.0f;
66 const float beta = 0.0f;
67 CUBLAS_CHECK_VOID(cublasGemmBatchedEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addrs, type_b, ldb, a_addrs,
68 type_a, lda, &beta, c_addrs, type_c, ldc, batch, compute_type,
69 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
70 }
71
CublasGemmWrapper(const void * a_addr,const void * b_addr,void * c_addr,const int * params,const int * lds,const cublasOperation_t * operations,const cudaDataType * data_types,void * alpha,void * beta,cublasHandle_t cublas_handle,cublasGemmAlgo_t algo)72 void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
73 const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
74 cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
75 const int m = params[0];
76 const int n = params[1];
77 const int k = params[2];
78 cublasOperation_t trans_a = operations[0];
79 cublasOperation_t trans_b = operations[1];
80 const int lda = lds[0];
81 const int ldb = lds[1];
82 const int ldc = lds[2];
83 cudaDataType type_a = data_types[0];
84 cudaDataType type_b = data_types[1];
85 cudaDataType type_c = data_types[2];
86 cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
87 if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
88 compute_type = CUBLAS_COMPUTE_16F;
89 }
90 CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
91 ldb, beta, c_addr, type_c, ldc, compute_type, algo));
92 }
93
CublasGemmStridedBatchedWrapper(const void * a_addr,const void * b_addr,void * c_addr,const int * params,const int * lds,const cublasOperation_t * operations,const int * strides,const cudaDataType * data_types,void * alpha,void * beta,int batch,cublasHandle_t cublas_handle,cublasGemmAlgo_t algo)94 void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
95 const int *lds, const cublasOperation_t *operations, const int *strides,
96 const cudaDataType *data_types, void *alpha, void *beta, int batch,
97 cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
98 const int m = params[0];
99 const int n = params[1];
100 const int k = params[2];
101 cublasOperation_t trans_a = operations[0];
102 cublasOperation_t trans_b = operations[1];
103 const int lda = lds[0];
104 const int ldb = lds[1];
105 const int ldc = lds[2];
106 cudaDataType type_a = data_types[0];
107 cudaDataType type_b = data_types[1];
108 cudaDataType type_c = data_types[2];
109 cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
110 if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
111 compute_type = CUBLAS_COMPUTE_16F;
112 }
113 const int stride_a = strides[0];
114 const int stride_b = strides[1];
115 const int stride_c = strides[2];
116
117 CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
118 stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
119 stride_c, batch, compute_type, algo));
120 }
121
CublasLtGemmWrapper(const void * a_addr,const void * b_addr,void * c_addr,const int * params,const int * lds,const cublasOperation_t * operations,const cudaDataType * data_types,void * alpha,void * beta,const void * bias,cudaStream_t stream,cublasLtHandle_t cublaslt_handle)122 void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
123 const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
124 const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) {
125 cublasOperation_t trans_a = operations[0];
126 cublasOperation_t trans_b = operations[1];
127 cudaDataType type_a = data_types[0];
128 cudaDataType type_b = data_types[1];
129 cudaDataType type_c = data_types[2];
130 const int m = params[0];
131 const int n = params[1];
132 const int k = params[2];
133 const int lda = lds[0];
134 const int ldb = lds[1];
135 const int ldc = lds[2];
136
137 cublasLtMatrixLayout_t mat_a_desc = NULL;
138 cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m,
139 lda);
140 cublasLtMatrixLayout_t mat_b_desc = NULL;
141 cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k,
142 ldb);
143 cublasLtMatrixLayout_t mat_c_desc = NULL;
144 cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc);
145
146 cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
147 if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
148 compute_type = CUBLAS_COMPUTE_16F;
149 }
150
151 cublasLtMatmulDesc_t mat_operation_desc = NULL;
152 cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a);
153 cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t));
154 cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t));
155 if (bias != nullptr) {
156 cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
157 cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
158 cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *));
159 }
160
161 cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr,
162 mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream);
163 cublasLtMatrixLayoutDestroy(mat_a_desc);
164 cublasLtMatrixLayoutDestroy(mat_b_desc);
165 cublasLtMatrixLayoutDestroy(mat_c_desc);
166 cublasLtMatmulDescDestroy(mat_operation_desc);
167 }
168 } // namespace mindspore::lite
169