1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <cuda_runtime.h>
17 #include "batchtospace_impl.cuh"
18 #include "include/cuda_fp16.h"
19
// Maps a flat BatchToSpace output offset to the flat input offset it is copied from.
//
// Both offsets are row-major over four extents: the output offset is decoded as
// [on][oc][oh][ow] (ow fastest) and the input offset is re-encoded as [in][ic][ih][iw].
// An output pixel (h, w) is first shifted by the top/left crop into un-cropped
// coordinates (h_full, w_full). The source sample comes from block group
// (h_full % block_num, w_full % block_num), each group holding `on` consecutive input
// samples, and the source spatial location is (h_full / block_num, w_full / block_num).
//
// Preconditions (not checked): pos < on * oc * oh * ow, and oh, ow, oc, block_num > 0.
// The channel index decoded from the output is used directly as the input channel with
// stride `ic`, so ic presumably equals oc — TODO confirm against the caller.
static __host__ __device__ inline size_t BatchToSpaceInputPos(const size_t pos, const size_t ih, const size_t iw,
                                                              const size_t ic, const size_t on, const size_t oh,
                                                              const size_t ow, const size_t oc, const size_t crop_up,
                                                              const size_t crop_lft, const size_t block_num) {
  // Decode the output offset into (n, c, h, w).
  const size_t hw = oh * ow;
  const size_t chw = oc * hw;
  const size_t idx_on = pos / chw;
  const size_t idx_oc = (pos % chw) / hw;
  const size_t idx_oh = (pos % hw) / ow;
  const size_t idx_ow = pos % ow;

  // Position in the output before the top/left crop was removed.
  const size_t h_full = idx_oh + crop_up;
  const size_t w_full = idx_ow + crop_lft;

  // Which of the block_num x block_num interleaved input groups supplies this pixel.
  const size_t blk_row = h_full % block_num;
  const size_t blk_col = w_full % block_num;
  const size_t idx_in = (blk_row * block_num + blk_col) * on + idx_on;

  // Spatial location inside that input sample.
  const size_t idx_ih = h_full / block_num;
  const size_t idx_iw = w_full / block_num;

  return ((idx_in * ic + idx_oc) * ih + idx_ih) * iw + idx_iw;
}

// BatchToSpace kernel: for every pos in [0, size), output[pos] = input[BatchToSpaceInputPos(pos, ...)].
//
// Launch: 1-D grid of 1-D blocks, any size. Each thread walks the output with a grid-stride
// loop, so correctness does not depend on the grid covering `size`. No shared memory is used.
// `size` is expected to equal on * oc * oh * ow (not checked here).
// `in`, `crop_dn` and `crop_rht` are not read by this kernel and are kept for interface
// compatibility; the caller presumably already folded them into on/oh/ow — TODO confirm.
template <typename T>
__global__ void BatchToSpace(const size_t size, const T *input, const size_t in, const size_t ih, const size_t iw,
                             const size_t ic, const size_t on, const size_t oh, const size_t ow, const size_t oc,
                             const size_t crop_up, const size_t crop_dn, const size_t crop_lft, const size_t crop_rht,
                             const size_t block_num, T *output) {
  // Widen to size_t *before* multiplying: blockIdx.x * blockDim.x and blockDim.x * gridDim.x
  // would otherwise be evaluated in 32-bit unsigned arithmetic and could wrap.
  const size_t grid_stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += grid_stride) {
    output[pos] = input[BatchToSpaceInputPos(pos, ih, iw, ic, on, oh, ow, oc, crop_up, crop_lft, block_num)];
  }
}
58
// Host entry point: enqueues the BatchToSpace kernel on `cuda_stream` and returns the
// status reported by GetCudaStatus() (defined elsewhere in the project; presumably it
// surfaces launch errors — confirm there).
//
// The launch shape comes from the project's per-device helpers CUDA_BLOCKS / CUDA_THREADS.
// The kernel uses a grid-stride loop, so it does not require the grid to cover `size`.
// The launch is asynchronous with respect to the host; no synchronization happens here.
template <typename T>
cudaError_t CalBatchToSpace(const size_t size, const T *input, const size_t in, const size_t ih, const size_t iw,
                            const size_t ic, const size_t on, const size_t oh, const size_t ow, const size_t oc,
                            const size_t crop_up, const size_t crop_dn, const size_t crop_lft, const size_t crop_rht,
                            const size_t block_num, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
  const auto grid_dim = CUDA_BLOCKS(device_id, size);
  const auto block_dim = CUDA_THREADS(device_id);
  constexpr size_t kDynamicSharedBytes = 0;  // the kernel uses no shared memory

  BatchToSpace<T><<<grid_dim, block_dim, kDynamicSharedBytes, cuda_stream>>>(
    size, input, in, ih, iw, ic, on, oh, ow, oc, crop_up, crop_dn, crop_lft, crop_rht, block_num, output);

  return GetCudaStatus();
}
68
// Explicit instantiations of CalBatchToSpace for every element type exported by this library.
// The helper macro keeps all instantiations on one identical signature and is #undef'd
// immediately afterwards so it does not leak out of this file.
#define BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(DType)                                                                \
  template CUDA_LIB_EXPORT cudaError_t CalBatchToSpace<DType>(                                                      \
    const size_t size, const DType *input, const size_t in, const size_t ih, const size_t iw, const size_t ic,      \
    const size_t on, const size_t oh, const size_t ow, const size_t oc, const size_t crop_up, const size_t crop_dn, \
    const size_t crop_lft, const size_t crop_rht, const size_t block_num, DType *output, const uint32_t &device_id, \
    cudaStream_t cuda_stream)

BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(float);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(half);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(int);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(int64_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(int16_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(int8_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(uint8_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(uint16_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(uint32_t);
BATCH_TO_SPACE_EXPLICIT_INSTANTIATION(uint64_t);

#undef BATCH_TO_SPACE_EXPLICIT_INSTANTIATION
125