• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "triplet_margin_loss_impl.cuh"
18 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
19 
// Broadcast indexing rule: a source dimension of extent 1 always maps to
// coordinate 0; otherwise the destination coordinate passes through unchanged.
__device__ __forceinline__ int64_t Index(const int64_t &index, const int64_t &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}
21 
// Expands the three inputs (anchor, positive, negative) to a common broadcast
// shape. The expanded copies are written contiguously starting at
// anchor_broadcast: anchor at offset 0, positive at offset size, negative at
// offset 2 * size.
// NOTE(review): this layout assumes positive_broadcast == anchor_broadcast + size
// and negative_broadcast == anchor_broadcast + 2 * size, i.e. the three buffers
// are one contiguous allocation — confirm against the caller's allocation.
//
// size          - element count of one broadcast tensor (product of dst_shape)
// shape_size    - rank of the shapes; dst_index_array caps this at 8 — TODO confirm caller enforces rank <= 8
// tensor_shapes - the 3 source shapes concatenated, each shape_size entries long
// dst_shape     - broadcast target shape, shape_size entries long
template <typename T>
__global__ void FillAndBroadcast(const int64_t size, const size_t shape_size, const int64_t *tensor_shapes,
                                 const int64_t *dst_shape, const T *anchor, const T *positive, const T *negative,
                                 T *anchor_broadcast) {
  const T *pair_tensor[3] = {anchor, positive, negative};
  // Grid-stride loop over all 3 * size output elements; mode (0/1/2) selects
  // which of the three tensors this element belongs to.
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < 3 * size; pos += blockDim.x * gridDim.x) {
    const size_t mode = pos / size;
    const int64_t *src_shape = tensor_shapes + shape_size * mode;
    size_t tmp_pos = pos % size;
    size_t pos_size = size / dst_shape[0];
    // Decompose the flat destination offset into per-dimension coordinates
    // (row-major order, outermost dimension first).
    size_t dst_index_array[8];
    dst_index_array[0] = tmp_pos / pos_size;
    for (size_t i = 1; i < shape_size; i++) {
      tmp_pos -= dst_index_array[i - 1] * pos_size;
      pos_size = pos_size / dst_shape[i];
      dst_index_array[i] = tmp_pos / pos_size;
    }
    // Total element count of this source tensor; consumed dimension by
    // dimension below to form row-major strides.
    size_t src_size = 1;
    for (size_t i = 0; i < shape_size; i++) {
      src_size *= src_shape[i];
    }
    // Re-linearize the coordinates against the source shape. Index() collapses
    // coordinates on extent-1 dimensions to 0, which is the broadcast rule.
    size_t src_pos = 0;
    for (size_t i = 0; i < shape_size; i++) {
      src_size /= src_shape[i];
      size_t length_by_index = Index(dst_index_array[i], src_shape[i]) * src_size;
      src_pos += length_by_index;
    }
    (anchor_broadcast + mode * size)[pos % size] = pair_tensor[mode][src_pos];
  }
  return;
}
53 
// Computes the p-norm distance over the "bound" axis for up to three tensor
// pairs, writing one distance per (mode, x, y) into tem_output:
//   mode 0: pair (anchor, positive)    -> indices (0, 1)
//   mode 1: pair (anchor, negative)    -> indices (0, 2)
//   mode 2: pair (positive, negative)  -> indices (1, 2), present only when n == 3 (swap)
// The index pairs come from the integer tricks mode / 2 and (mode + 3) / 2.
//
// bound_list - per-mode number of elements actually summed along the bound axis
//              (presumably the un-broadcast extent of that pair — TODO confirm)
// n          - number of distance maps to produce (2, or 3 when swap is enabled)
// eps        - added to the element difference before abs()
template <typename T>
__global__ void PairwiseDistance(const T *anchor, const T *positive, const T *negative, const size_t *bound_list,
                                 const size_t bound, const size_t outer_size, const size_t inner_size,
                                 float *tem_output, const size_t n, const int64_t p, const float eps) {
  const T *pair_tensor[3] = {anchor, positive, negative};
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * outer_size * inner_size;
       pos += gridDim.x * blockDim.x) {
    size_t mode = pos / (outer_size * inner_size);
    size_t idx = pos % (outer_size * inner_size);
    float res = 0;
    size_t x = idx / inner_size % outer_size;
    size_t y = idx % inner_size;
    for (int i = 0; i < bound_list[mode]; i++) {
      // Element (x, i, y) in the [outer, bound, inner] layout.
      size_t input_offset = x * bound * inner_size + i * inner_size + y;
      // NOTE(review): for unsigned integer T the subtraction wraps before the
      // cast/abs — verify unsigned inputs are intended to take this path.
      float base =
        abs(static_cast<T>(pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset]) + eps);
      float tem = pow(base, static_cast<float>(p));
      res += tem;
    }
    // Final p-norm: (sum |d|^p)^(1/p).
    tem_output[pos] = pow(res, static_cast<float>(1.0 / p));
  }
  return;
}
77 
// Degenerate case p == 0: each summed term pow(base, 0) would be 1, so the
// distance for a pair is simply the number of summed elements, bound_list[mode].
__global__ void PairwiseDistancePzero(const size_t *bound_list, const size_t output_size, float *tem_output,
                                      const size_t n) {
  const size_t total = n * output_size;
  const size_t step = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += step) {
    const size_t pair_id = idx / output_size;
    tem_output[idx] = static_cast<float>(bound_list[pair_id]);
  }
}
86 
// Distance-swap step of TripletMarginLoss: replace d(anchor, negative) with
// min(d(anchor, negative), d(positive, negative)).
// tem_output layout: [0, size) = d(a,p); [size, 2*size) = d(a,n);
// [2*size, 3*size) = d(p,n).
__global__ void SwapTrue(float *tem_output, const size_t output_size) {
  const size_t step = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < output_size; idx += step) {
    const float dist_neg = tem_output[idx + output_size];
    const float dist_swap = tem_output[idx + 2 * output_size];
    // Only store when strictly smaller; otherwise the slot already holds the
    // minimum (same observable result as the original ternary).
    if (dist_neg > dist_swap) {
      tem_output[idx + output_size] = dist_swap;
    }
  }
}
95 
// Element-wise triplet hinge loss:
//   output[pos] = max(margin + d(a,p) - d(a,n), 0)
// where tem_output[0, output_size) holds d(anchor, positive) and
// tem_output[output_size, 2 * output_size) holds d(anchor, negative).
//
// output may alias tem_output (the reduction path writes the loss in place).
// margin is a pointer to a single device-resident scalar.
__global__ void MaxReduction(float *tem_output, float *output, const size_t output_size, const float *margin) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) {
    // Use a float literal (0.0f): the original bare 0.0 is a double literal,
    // which promoted the whole max() to double precision for no benefit.
    output[pos] = max(static_cast<float>(margin[0]) + tem_output[pos] - tem_output[pos + output_size], 0.0f);
  }
  return;
}
102 
// Single-thread fix-up kernel: fold the unpaired tail element at `index` into
// accumulator slot 0 so the tree reduction can proceed on an even count.
__global__ void AddTile(float *tmp_loss, size_t index) {
  const float tail = tmp_loss[index];
  tmp_loss[0] += tail;
}
104 
// One level of a pairwise tree reduction: tmp_loss[i] += tmp_loss[i + stride]
// for every i in [0, stride).
__global__ void PartialSum(float *tmp_loss, size_t stride) {
  const size_t step = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < stride; idx += step) {
    tmp_loss[idx] = tmp_loss[idx] + tmp_loss[idx + stride];
  }
}
110 
// Single-thread finalization: divide the fully reduced sum (tem_output[0]) by
// k and store it as the scalar result. k == output_size gives the mean,
// k == 1 gives the sum.
template <typename S>
__global__ void ReductionDivde(S *output, float *tem_output, const size_t k) {
  const float reduced = tem_output[0];
  output[0] = reduced / k;
}
115 
// double specialization of PairwiseDistance: accumulates in double and applies
// pow in double before the final narrowing store into the float tem_output.
// Pair selection is identical to the generic version:
//   mode 0 -> (anchor, positive), mode 1 -> (anchor, negative),
//   mode 2 -> (positive, negative) when n == 3 (swap).
template <>
__global__ void PairwiseDistance(const double *anchor, const double *positive, const double *negative,
                                 const size_t *bound_list, const size_t bound, const size_t outer_size,
                                 const size_t inner_size, float *tem_output, const size_t n, const int64_t p,
                                 const float eps) {
  const double *pair_tensor[3] = {anchor, positive, negative};
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * outer_size * inner_size;
       pos += gridDim.x * blockDim.x) {
    size_t mode = pos / (outer_size * inner_size);
    size_t idx = pos % (outer_size * inner_size);
    double res = 0;
    size_t x = idx / inner_size % outer_size;
    size_t y = idx % inner_size;
    for (int i = 0; i < bound_list[mode]; i++) {
      // Element (x, i, y) in the [outer, bound, inner] layout.
      size_t input_offset = x * bound * inner_size + i * inner_size + y;
      double base =
        abs(static_cast<double>(pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset]) + eps);
      double tem = pow(base, static_cast<double>(p));
      res += tem;
    }
    // Final p-norm, narrowed to float on store.
    tem_output[pos] = pow(res, static_cast<double>(1.0 / p));
  }
  return;
}
141 
// half specialization of PairwiseDistance: each half operand is widened with
// __half2float and all arithmetic/accumulation happens in float for precision.
// Pair selection is identical to the generic version:
//   mode 0 -> (anchor, positive), mode 1 -> (anchor, negative),
//   mode 2 -> (positive, negative) when n == 3 (swap).
template <>
__global__ void PairwiseDistance(const half *anchor, const half *positive, const half *negative,
                                 const size_t *bound_list, const size_t bound, const size_t outer_size,
                                 const size_t inner_size, float *tem_output, const size_t n, const int64_t p,
                                 const float eps) {
  const half *pair_tensor[3] = {anchor, positive, negative};
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * outer_size * inner_size;
       pos += gridDim.x * blockDim.x) {
    size_t mode = pos / (outer_size * inner_size);
    size_t idx = pos % (outer_size * inner_size);
    float res = 0;
    size_t x = idx / inner_size % outer_size;
    size_t y = idx % inner_size;
    for (int i = 0; i < bound_list[mode]; i++) {
      // Element (x, i, y) in the [outer, bound, inner] layout.
      size_t input_offset = x * bound * inner_size + i * inner_size + y;
      float base = abs(__half2float(pair_tensor[mode / 2][input_offset]) -
                       __half2float(pair_tensor[(mode + 3) / 2][input_offset]) + eps);
      float tem = pow(base, static_cast<float>(p));
      res += tem;
    }
    // Final p-norm: (sum |d|^p)^(1/p).
    tem_output[pos] = pow(res, static_cast<float>(1.0 / p));
  }
  return;
}
167 
// half-output variant of the element-wise hinge loss:
//   output[idx] = (half)max(margin + d(a,p) - d(a,n), 0)
// Distances are read from the float scratch buffer; only the final store is
// converted to half.
__global__ void MaxReduction(float *tem_output, half *output, const size_t output_size, const float *margin) {
  const float zero = 0;
  const size_t step = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < output_size; idx += step) {
    const float hinge = margin[0] + tem_output[idx] - tem_output[idx + output_size];
    output[idx] = __float2half(max(hinge, zero));
  }
}
176 
// half specialization of the reduction finalizer: compute the quotient in
// float, convert once to half on store.
template <>
__global__ void ReductionDivde(half *output, float *tem_output, const size_t k) {
  const float mean_loss = tem_output[0] / k;
  output[0] = __float2half(mean_loss);
}
182 
// Complex<S> specialization of PairwiseDistance: the per-element magnitude is
// the complex modulus sqrt(re^2 + im^2) of the difference. Accumulation runs
// in S (the component type); the result is narrowed to float on store.
// NOTE(review): eps is added to the complex difference as a scalar, i.e. it
// shifts only the real part before the modulus — confirm this matches the
// intended eps semantics of the real-valued paths.
template <typename S>
__global__ void PairwiseDistance(const Complex<S> *anchor, const Complex<S> *positive, const Complex<S> *negative,
                                 const size_t *bound_list, const size_t bound, const size_t outer_size,
                                 const size_t inner_size, float *tem_output, const size_t n, const int64_t p,
                                 const float eps) {
  const Complex<S> *pair_tensor[3] = {anchor, positive, negative};
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * outer_size * inner_size;
       pos += gridDim.x * blockDim.x) {
    // mode 0 -> (anchor, positive), mode 1 -> (anchor, negative),
    // mode 2 -> (positive, negative) when n == 3 (swap).
    size_t mode = pos / (outer_size * inner_size);
    size_t idx = pos % (outer_size * inner_size);
    S res = 0;
    size_t x = idx / inner_size % outer_size;
    size_t y = idx % inner_size;
    for (int i = 0; i < bound_list[mode]; i++) {
      // Element (x, i, y) in the [outer, bound, inner] layout.
      size_t input_offset = x * bound * inner_size + i * inner_size + y;
      Complex<S> base_complex =
        pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset] + static_cast<S>(eps);
      // Complex modulus |z| = sqrt(re^2 + im^2).
      S base = sqrt((base_complex.real() * base_complex.real() + base_complex.imag() * base_complex.imag()));
      S tem = pow(base, static_cast<S>(p));
      res += tem;
    }
    // Final p-norm, narrowed to float on store.
    tem_output[pos] = pow(res, static_cast<S>(1.0 / p));
  }
  return;
}
209 
// Host-side driver for TripletMarginLoss on GPU. Pipeline:
//   1. optionally broadcast anchor/positive/negative to a common shape,
//   2. compute pairwise p-norm distances d(a,p), d(a,n) (and d(p,n) when swap),
//   3. swap: d(a,n) <- min(d(a,n), d(p,n)),
//   4. hinge: max(margin + d(a,p) - d(a,n), 0),
//   5. reduce per `reduction` ("none" | "mean" | anything else = sum).
// All kernels are enqueued on cuda_stream; returns the CUDA status via
// GetCudaStatus(). tem_output is float scratch sized for n distance maps.
template <typename T, typename S>
cudaError_t CalTripletMarginLoss(const T *anchor, const T *positive, const T *negative, T *anchor_broadcast,
                                 T *positive_broadcast, T *negative_broadcast, S *output, float *tem_output,
                                 const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size,
                                 const size_t inner_size, const size_t *bound_list, const size_t bound,
                                 const size_t shape_size, float *margin, const int64_t p, const float eps,
                                 const std::string reduction, const bool swap, const bool need_broadcast,
                                 const uint32_t &device_id, cudaStream_t cuda_stream) {
  const int64_t size = outer_size * inner_size * bound;
  // Number of distance maps: d(a,p) and d(a,n), plus d(p,n) when swap is on.
  size_t n = swap ? 3 : 2;
  const size_t output_size = outer_size * inner_size;
  // Clamp the block size to the work size (max 256 threads per block).
  size_t block_num = 256 > n * output_size ? n * output_size : 256;
  if (p == 0) {
    // p == 0 degenerates to a per-pair element count; no broadcast needed.
    PairwiseDistancePzero<<<CUDA_BLOCKS_CAL(device_id, n * output_size, block_num), block_num, 0, cuda_stream>>>(
      bound_list, output_size, tem_output, n);
  } else if (need_broadcast) {
    block_num = 256 > 3 * size ? 3 * size : 256;
    // Expand all three inputs into the (contiguous) *_broadcast buffers, then
    // measure distances on the expanded copies.
    FillAndBroadcast<<<CUDA_BLOCKS_CAL(device_id, 3 * size, block_num), block_num, 0, cuda_stream>>>(
      size, shape_size, tensor_shapes, dst_shape, anchor, positive, negative, anchor_broadcast);
    block_num = 256 > n * output_size ? n * output_size : 256;
    PairwiseDistance<<<CUDA_BLOCKS_CAL(device_id, n * output_size, block_num), block_num, 0, cuda_stream>>>(
      anchor_broadcast, positive_broadcast, negative_broadcast, bound_list, bound, outer_size, inner_size, tem_output,
      n, p, eps);
  } else {
    PairwiseDistance<<<CUDA_BLOCKS_CAL(device_id, n * output_size, block_num), block_num, 0, cuda_stream>>>(
      anchor, positive, negative, bound_list, bound, outer_size, inner_size, tem_output, n, p, eps);
  }
  block_num = 256 > output_size ? output_size : 256;
  if (swap) {
    // d(a,n) <- min(d(a,n), d(p,n)), in place in tem_output.
    SwapTrue<<<CUDA_BLOCKS_CAL(device_id, output_size, block_num), block_num, 0, cuda_stream>>>(tem_output,
                                                                                                output_size);
  }
  if (reduction == "none") {
    // Per-element hinge loss written straight to the user output.
    MaxReduction<<<CUDA_BLOCKS_CAL(device_id, output_size, block_num), block_num, 0, cuda_stream>>>(
      tem_output, output, output_size, margin);
  } else {
    // Hinge loss written in place over tem_output, then tree-reduced to a scalar.
    MaxReduction<<<CUDA_BLOCKS_CAL(device_id, output_size, block_num), block_num, 0, cuda_stream>>>(
      tem_output, tem_output, output_size, margin);
    // Odd element count: fold the unpaired tail element into slot 0 so each
    // halving step below operates on an even range.
    if (output_size % 2 == 1 && output_size != 1) {
      AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, output_size - 1);
    }
    for (size_t stride = output_size / 2; stride > 0; stride >>= 1) {
      block_num = 256 > stride ? stride : 256;
      PartialSum<<<CUDA_BLOCKS_CAL(device_id, stride, block_num), block_num, 0, cuda_stream>>>(tem_output, stride);
      // Same odd-count fix-up between levels (stride <= 2 needs none).
      if (stride > 2 && stride % 2 == 1) {
        AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, stride - 1);
      }
    }
    if (reduction == "mean") {
      // mean: divide the reduced sum by the element count.
      ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, output_size);
    } else {
      // sum: divide by 1, i.e. store the reduced sum as-is.
      ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, 1);
    }
  }
  return GetCudaStatus();
}
266 
// Explicit template instantiations exported from this translation unit.
// Integer and real inputs reduce to a float loss; half keeps a half loss;
// Complex inputs reduce to a float loss.

// Signed integer inputs -> float output.
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<int8_t, float>(
  const int8_t *anchor, const int8_t *positive, const int8_t *negative, int8_t *anchor_broadcast,
  int8_t *positive_broadcast, int8_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<int16_t, float>(
  const int16_t *anchor, const int16_t *positive, const int16_t *negative, int16_t *anchor_broadcast,
  int16_t *positive_broadcast, int16_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<int32_t, float>(
  const int32_t *anchor, const int32_t *positive, const int32_t *negative, int32_t *anchor_broadcast,
  int32_t *positive_broadcast, int32_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<int64_t, float>(
  const int64_t *anchor, const int64_t *positive, const int64_t *negative, int64_t *anchor_broadcast,
  int64_t *positive_broadcast, int64_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
// Unsigned integer inputs -> float output.
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<uint8_t, float>(
  const uint8_t *anchor, const uint8_t *positive, const uint8_t *negative, uint8_t *anchor_broadcast,
  uint8_t *positive_broadcast, uint8_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<uint16_t, float>(
  const uint16_t *anchor, const uint16_t *positive, const uint16_t *negative, uint16_t *anchor_broadcast,
  uint16_t *positive_broadcast, uint16_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<uint32_t, float>(
  const uint32_t *anchor, const uint32_t *positive, const uint32_t *negative, uint32_t *anchor_broadcast,
  uint32_t *positive_broadcast, uint32_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<uint64_t, float>(
  const uint64_t *anchor, const uint64_t *positive, const uint64_t *negative, uint64_t *anchor_broadcast,
  uint64_t *positive_broadcast, uint64_t *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
// Real floating-point inputs -> float output.
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<double, float>(
  const double *anchor, const double *positive, const double *negative, double *anchor_broadcast,
  double *positive_broadcast, double *negative_broadcast, float *output, float *tem_output,
  const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
  const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, const int64_t p,
  const float eps, const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<float, float>(
  const float *anchor, const float *positive, const float *negative, float *anchor_broadcast, float *positive_broadcast,
  float *negative_broadcast, float *output, float *tem_output, const int64_t *tensor_shapes, const int64_t *dst_shape,
  const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
  const size_t shape_size, float *margin, const int64_t p, const float eps, const std::string reduction,
  const bool swap, const bool need_broadcast, const uint32_t &device_id, cudaStream_t cuda_stream);

// half inputs keep a half output (uses the half MaxReduction/ReductionDivde overloads).
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<half, half>(
  const half *anchor, const half *positive, const half *negative, half *anchor_broadcast, half *positive_broadcast,
  half *negative_broadcast, half *output, float *tem_output, const int64_t *tensor_shapes, const int64_t *dst_shape,
  const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
  const size_t shape_size, float *margin, const int64_t p, const float eps, const std::string reduction,
  const bool swap, const bool need_broadcast, const uint32_t &device_id, cudaStream_t cuda_stream);

// Complex inputs -> float output (distance uses the complex modulus).
template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<Complex<float>, float>(
  const Complex<float> *anchor, const Complex<float> *positive, const Complex<float> *negative,
  Complex<float> *anchor_broadcast, Complex<float> *positive_broadcast, Complex<float> *negative_broadcast,
  float *output, float *tem_output, const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size,
  const size_t inner_size, const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
  const int64_t p, const float eps, const std::string reduction, const bool swap, const bool need_broadcast,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT cudaError_t CalTripletMarginLoss<Complex<double>, float>(
  const Complex<double> *anchor, const Complex<double> *positive, const Complex<double> *negative,
  Complex<double> *anchor_broadcast, Complex<double> *positive_broadcast, Complex<double> *negative_broadcast,
  float *output, float *tem_output, const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size,
  const size_t inner_size, const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
  const int64_t p, const float eps, const std::string reduction, const bool swap, const bool need_broadcast,
  const uint32_t &device_id, cudaStream_t cuda_stream);
359