• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
18 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
19 
20 #include <cublasLt.h>
21 #include <cublas_v2.h>
22 #include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h"
23 #include "src/common/log_util.h"
24 
25 // cublas API error checking
// Evaluates a cublas call inside a void-returning function; on failure, logs
// the numeric status AND the stringified call expression (so the log shows
// which API call failed), then returns early.
#define CUBLAS_CHECK_VOID(err)                                             \
  do {                                                                     \
    cublasStatus_t cublas_err = (err);                                     \
    if (cublas_err != CUBLAS_STATUS_SUCCESS) {                             \
      MS_LOG(ERROR) << "cublas error " << cublas_err << " in " << #err;    \
      return;                                                              \
    }                                                                      \
  } while (0)
34 
// Evaluates a cublas call inside an int-returning function; on failure, logs
// the numeric status AND the stringified call expression (so the log shows
// which API call failed), then returns -1 to the caller.
#define CUBLAS_CHECK(err)                                                  \
  do {                                                                     \
    cublasStatus_t cublas_err = (err);                                     \
    if (cublas_err != CUBLAS_STATUS_SUCCESS) {                             \
      MS_LOG(ERROR) << "cublas error " << cublas_err << " in " << #err;    \
      return -1;                                                           \
    }                                                                      \
  } while (0)
43 
namespace mindspore::lite {
// 2-D matrix transpose via cuBLAS.
// a: m * n
// params order: m, n
// in_addr/out_addr are device pointers; out receives the transpose of in
// (presumably n * m on output — confirm against implementation).
void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle);

// Single (batch-of-one) matrix multiply: c = op(a) * op(b).
// a: m * k, b: k * n, c: m * n
// params order: m, n, k
// operations order: trans_a, trans_b
// data_types: type_a, type_b, type_c, compute type
void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                    const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle);

// Batched matrix multiply over per-matrix device pointer arrays:
// c[i] = op(a[i]) * op(b[i]) for i in [0, batch).
// a: batch * m * k, b: batch * k * n, c: batch * m * n
// params order: m, n, k, batch
// operations order: trans_a, trans_b
// data_types: type_a, type_b, type_c, compute type
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
                     const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle);

// GEMM wrapper with caller-supplied scaling factors and leading dimensions.
// params presumably follows the m, n, k ordering used above — confirm against
// implementation. lds are the leading dimensions (presumably of a, b, c).
// alpha/beta point to scalars matching the compute type in data_types.
// NOTE(review): default algo CUBLAS_GEMM_DEFAULT_TENSOR_OP is deprecated in
// newer CUDA toolkits — verify against the toolkit version this targets.
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                       const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                       cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// Strided-batched GEMM wrapper: batch matrices laid out contiguously with the
// given element strides (presumably strides of a, b, c — confirm against
// implementation). Other parameters as in CublasGemmWrapper.
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                                     const int *lds, const cublasOperation_t *operations, const int *strides,
                                     const cudaDataType *data_types, void *alpha, void *beta, int batch,
                                     cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);

// cublasLt-based GEMM that additionally takes an optional bias device pointer
// and runs on the given stream (uses a cublasLtHandle_t rather than the
// classic cublasHandle_t). Epilogue semantics depend on the implementation.
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                         const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                         const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle);
}  // namespace mindspore::lite
75 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
76