• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h"
18 #include <cmath>
19 #include "src/common/log_util.h"
20 
GetInstance()21 CudaHelper &CudaHelper::GetInstance() {
22   static CudaHelper instance;
23   return instance;
24 }
GetThreadNum() const25 int CudaHelper::GetThreadNum() const { return threads_per_block_; }
GetThreadNum(const int block_size) const26 int CudaHelper::GetThreadNum(const int block_size) const {
27   return std::min(threads_per_block_, ((block_size - 1) / 32 + 1) * 32);
28 }
GetBlocksNum(const int total_threads) const29 int CudaHelper::GetBlocksNum(const int total_threads) const {
30   return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_);
31 }
GetBlocksNum(const int total_threads,const int block_size) const32 int CudaHelper::GetBlocksNum(const int total_threads, const int block_size) const {
33   int valid_block_size = std::min(block_size, threads_per_block_);
34   if (valid_block_size == 0) {
35     MS_LOG(ERROR) << "invalid input of block_size: " << block_size;
36     return 0;
37   }
38   return std::min(((total_threads - 1) / valid_block_size) + 1, max_blocks_);
39 }
40 
CudaHelper()41 CudaHelper::CudaHelper() {
42   int device_id = 0;
43   (void)cudaGetDevice(&device_id);
44   cudaDeviceProp prop;
45   (void)cudaGetDeviceProperties(&prop, device_id);
46   threads_per_block_ = prop.maxThreadsPerBlock;
47   max_blocks_ = prop.multiProcessorCount;
48 }
49