// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_linear1d_native.h>
#include <ATen/ops/upsample_linear1d_backward_native.h>
#endif

namespace at::native {
namespace {

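// Forward kernel: one thread per output position along the width axis; each
// thread loops over batch and channel. Assuming the usual PyTorch coordinate
// conventions (align_corners=false maps w2 to rwidth * (w2 + 0.5) - 0.5,
// clamped at 0; align_corners=true maps w2 to rwidth * w2), upsampling
// width 2 -> 4 with align_corners=false gives rwidth = 0.5, and for w2 = 1
// the source coordinate is 0.25, so odata[n][c][1] becomes
// 0.75 * idata[n][c][0] + 0.25 * idata[n][c][1].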
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void upsample_linear1d_out_frame(
    const int n,
    const accscalar_t rwidth,
    const bool align_corners,
    const PackedTensorAccessor64<const scalar_t, 3> idata,
    PackedTensorAccessor64<scalar_t, 3> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int width1 = idata.size(2);
  const int width2 = odata.size(2);

  if (index < n) {
    const int w2 = index % width2;
    // special case: just copy
    if (width1 == width2) {
      const int w1 = w2;
      for (int n = 0; n < batchsize; n++) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = idata[n][c][w1];
          odata[n][c][w2] = val;
        }
      }
      return;
    }
    // Map the output index to a fractional source coordinate, then split it
    // into a base index (w1), a neighbor offset (w1p), and the two weights.
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    // Blend the two neighboring input samples for every (batch, channel).
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const accscalar_t val =
            w0lambda * idata[n][c][w1] + w1lambda * idata[n][c][w1 + w1p];
        odata[n][c][w2] = static_cast<scalar_t>(val);
      }
    }
  }
}

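// The backward kernel mirrors this indexing: each thread reads one
// grad_output element and scatters it into the (at most) two grad_input
// positions that produced it, weighted by w0lambda and w1lambda. Several
// output positions can hit the same input position, so the adds go through
// gpuAtomicAddNoReturn and the accumulation order is not fixed across runs.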
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void upsample_linear1d_out_frame_backward(
    const int n,
    const accscalar_t rwidth,
    const bool align_corners,
    PackedTensorAccessor64<scalar_t, 3> idata,
    const PackedTensorAccessor64<const scalar_t, 3> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int width1 = idata.size(2);
  const int width2 = odata.size(2);

  if (index < n) {
    const int w2 = index % width2;
    // special case: just copy
    if (width1 == width2) {
      const int w1 = w2;
      for (int n = 0; n < batchsize; n++) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = odata[n][c][w1];
          idata[n][c][w2] = val;
        }
      }
      return;
    }
    // Same source-index computation as in the forward kernel.
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    // Scatter the gradient into the two contributing input positions.
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const scalar_t d2val = odata[n][c][w2];
        gpuAtomicAddNoReturn(
            &idata[n][c][w1], static_cast<scalar_t>(w0lambda * d2val));
        gpuAtomicAddNoReturn(
            &idata[n][c][w1 + w1p], static_cast<scalar_t>(w1lambda * d2val));
      }
    }
  }
}

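// Host-side launcher for the forward kernel: num_kernels = output_width
// threads in total, grouped into ceil_div(output_width, 512) blocks of 512
// threads (e.g. an output width of 1000 launches 2 blocks).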
static void upsample_linear1d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  int output_width = output_size[0];

  int input_width = input.size(2);

  output.zero_();

  AT_ASSERT(input_width > 0 && output_width > 0);

  const int num_kernels = output_width;
  const int num_threads = 512;
      //at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_linear1d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<const scalar_t, 3>();
        auto odata = output.packed_accessor64<scalar_t, 3>();

        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales);

        upsample_linear1d_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(num_kernels, rwidth, align_corners, idata, odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

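// Backward launcher: grad_input is zero-filled up front because the kernel
// accumulates into it with atomic adds rather than overwriting it;
// grad_output is made contiguous before its accessor is built.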
static void upsample_linear1d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales) {
  TensorArg grad_output_arg{grad_output_, "grad_output_", 1},
      grad_input_arg{grad_input, "grad_input", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  int output_width = output_size[0];

  int input_width = input_size[2];

  Tensor grad_output = grad_output_.contiguous();

  grad_input.zero_();

  const int num_kernels = output_width;
  const int num_threads = 512;
      //at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output.scalar_type(), "upsample_linear1d_out_frame_backward", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.packed_accessor64<scalar_t, 3>();
        auto odata = grad_output.packed_accessor64<const scalar_t, 3>();

        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales);

        upsample_linear1d_out_frame_backward<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(num_kernels, rwidth, align_corners, idata, odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

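// Structured-kernel entry points registered for the CUDA backend. As an
// illustrative call path (one typical route, not the only one), these are
// reached from Python via something like
//   torch.nn.functional.interpolate(x, size=L_out, mode="linear")
// on a 3-D CUDA tensor of shape (N, C, L).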
TORCH_IMPL_FUNC(upsample_linear1d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales,
    const Tensor& output
) {
  upsample_linear1d_out_cuda_template(output, input, output_size, align_corners, scales);
}

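// The backward entry point reports itself as nondeterministic; with
// torch.use_deterministic_algorithms(True) this op is expected to raise an
// error (or only warn in warn-only mode), since the atomicAdd accumulation
// order varies between runs.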
TORCH_IMPL_FUNC(upsample_linear1d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales,
    const Tensor& grad_input
) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_linear1d_backward_out_cuda");
  upsample_linear1d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales);
}

} // namespace at::native