/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h"
#include "src/common/log_util.h"

// cublas API error checking
#define CUBLAS_CHECK_VOID(err)                        \
  do {                                                \
    cublasStatus_t cublas_err = (err);                \
    if (cublas_err != CUBLAS_STATUS_SUCCESS) {        \
      MS_LOG(ERROR) << "cublas error " << cublas_err; \
      return;                                         \
    }                                                 \
  } while (0)

#define CUBLAS_CHECK(err)                             \
  do {                                                \
    cublasStatus_t cublas_err = (err);                \
    if (cublas_err != CUBLAS_STATUS_SUCCESS) {        \
      MS_LOG(ERROR) << "cublas error " << cublas_err; \
      return -1;                                      \
    }                                                 \
  } while (0)

namespace mindspore::lite {
// a: m * n
// params order: m, n
void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle);

// a: m * k, b: k * n, c: m * n
// params order: m, n, k
// operations order: trans_a, trans_b
// data_types order: type_a, type_b, type_c, compute_type
void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                    const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle);

// a: batch * m * k, b: batch * k * n, c: batch * m * n
// params order: m, n, k, batch
// operations order: trans_a, trans_b
// data_types order: type_a, type_b, type_c, compute_type
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
                     const cublasOperation_t *operations, const cudaDataType *data_types,
                     cublasHandle_t cublas_handle);

void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                       const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                       cublasHandle_t cublas_handle, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);

void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
                                     const int *lds, const cublasOperation_t *operations, const int *strides,
                                     const cudaDataType *data_types, void *alpha, void *beta, int batch,
                                     cublasHandle_t cublas_handle,
                                     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);

void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
                         const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
                         const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle);
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
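
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header). A minimal
// example of driving CublasMM1Batch for a single FP32 GEMM, assuming the
// caller has already created a cublasHandle_t `handle` and allocated device
// buffers d_a (m x k), d_b (k x n), d_c (m x n); the packing of
// params/operations/data_types follows the comment conventions above.
//
//   const int m = 4, n = 8, k = 16;
//   const int params[] = {m, n, k};                       // m, n, k
//   const cublasOperation_t operations[] = {CUBLAS_OP_N,  // trans_a
//                                           CUBLAS_OP_N}; // trans_b
//   const cudaDataType data_types[] = {CUDA_R_32F,        // type_a
//                                      CUDA_R_32F,        // type_b
//                                      CUDA_R_32F,        // type_c
//                                      CUDA_R_32F};       // compute_type
//   CublasMM1Batch(d_a, d_b, d_c, params, operations, data_types, handle);
//
// The CUBLAS_CHECK / CUBLAS_CHECK_VOID macros above are meant to wrap
// individual cuBLAS calls so that a non-success cublasStatus_t is logged and
// the enclosing function returns early; CUBLAS_CHECK_VOID is for functions
// returning void, CUBLAS_CHECK for functions returning int, e.g.
//
//   CUBLAS_CHECK_VOID(cublasSetStream(handle, stream));  // in a void function
// ---------------------------------------------------------------------------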