1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/UpSample.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_empty_affine_quantized.h>
11 #include <ATen/ops/_upsample_nearest_exact3d_native.h>
12 #include <ATen/ops/upsample_nearest3d_native.h>
13 #endif
14
15 #include <c10/util/irange.h>
16
17 #include <cstring>
18
19
20 namespace at {
21 namespace native {
22
23 // Define a typedef to dispatch to nearest_idx or nearest_exact_idx
24 typedef int64_t (*nn_compute_source_index_fn_t)(const float, int64_t, int64_t);
25
26 // at::native functions for the native_functions.yaml
27 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest3d_out_frame(scalar_t * odata,scalar_t * idata,int64_t input_depth,int64_t input_height,int64_t input_width,int64_t output_depth,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)28 static void upsample_nearest3d_out_frame(
29 scalar_t* odata,
30 scalar_t* idata,
31 int64_t input_depth,
32 int64_t input_height,
33 int64_t input_width,
34 int64_t output_depth,
35 int64_t output_height,
36 int64_t output_width,
37 int64_t nbatch,
38 int64_t channels,
39 std::optional<double> scales_d,
40 std::optional<double> scales_h,
41 std::optional<double> scales_w) {
42 float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
43 float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
44 float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
45
46 channels = channels * nbatch;
47 if (channels == 0 || output_depth == 0 || output_height == 0 || output_width == 0) {
48 return;
49 }
50 auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata);
51 auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata);
52
53 // special case: just copy
54 if (input_depth == output_depth && input_height == output_height && input_width == output_width) {
55 std::memcpy(o_p, i_p, channels * input_depth * input_height * input_width * sizeof(typename scalar_t::underlying));
56 return;
57 }
58
59 for (const auto d2 : c10::irange(output_depth)) {
60 const int64_t d1 =
61 nn_compute_source_index_fn(depth_scale, d2, input_depth);
62
63 for (const auto h2 : c10::irange(output_height)) {
64 const int64_t h1 =
65 nn_compute_source_index_fn(height_scale, h2, input_height);
66
67 for (const auto w2 : c10::irange(output_width)) {
68 const int64_t w1 =
69 nn_compute_source_index_fn(width_scale, w2, input_width);
70
71 const auto* pos1 = &i_p[d1 * input_height * input_width + h1 * input_width + w1];
72 auto* pos2 = &o_p[d2 * output_height * output_width + h2 * output_width + w2];
73
74 for (C10_UNUSED const auto c : c10::irange(channels)) {
75 pos2[0] = pos1[0];
76 pos1 += input_depth * input_height * input_width;
77 pos2 += output_depth * output_height * output_width;
78 }
79 }
80 }
81 }
82 }
83
84 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest3d_out_frame_nhwc(scalar_t * odata,scalar_t * idata,int64_t input_depth,int64_t input_height,int64_t input_width,int64_t output_depth,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)85 static void upsample_nearest3d_out_frame_nhwc(
86 scalar_t* odata,
87 scalar_t* idata,
88 int64_t input_depth,
89 int64_t input_height,
90 int64_t input_width,
91 int64_t output_depth,
92 int64_t output_height,
93 int64_t output_width,
94 int64_t nbatch,
95 int64_t channels,
96 std::optional<double> scales_d,
97 std::optional<double> scales_h,
98 std::optional<double> scales_w) {
99 float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
100 float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
101 float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
102
103 for (const auto b : c10::irange(nbatch)) {
104 auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata + b * input_depth * input_height * input_width * channels);
105 auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata + b * output_depth * output_height * output_width * channels);
106 // special case: just copy
107 if (input_depth == output_depth && input_height == output_height && input_width == output_width) {
108 std::memcpy(o_p, i_p, channels * input_depth * input_height * input_width * sizeof(typename scalar_t::underlying));
109 return;
110 }
111
112 for (const auto d2 : c10::irange(output_depth)) {
113 const int64_t d1 =
114 nn_compute_source_index_fn(depth_scale, d2, input_depth);
115 for (const auto h2 : c10::irange(output_height)) {
116 const int64_t h1 =
117 nn_compute_source_index_fn(height_scale, h2, input_height);
118
119 for (const auto w2 : c10::irange(output_width)) {
120 const int64_t w1 =
121 nn_compute_source_index_fn(width_scale, w2, input_width);
122
123 const auto* pos1 = &i_p[(d1 * input_height * input_width + h1 * input_width + w1)*channels];
124 auto* pos2 = &o_p[(d2 * output_height * output_width + h2 * output_width + w2)*channels];
125 std::memcpy(pos2, pos1, channels * sizeof(typename scalar_t::underlying));
126 }
127 }
128 }
129 }
130 }
131
132 template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
_upsample_nearest3d_quantized_cpu(const Tensor & input,IntArrayRef output_size,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)133 Tensor _upsample_nearest3d_quantized_cpu(
134 const Tensor& input,
135 IntArrayRef output_size,
136 std::optional<double> scales_d,
137 std::optional<double> scales_h,
138 std::optional<double> scales_w) {
139 TORCH_CHECK(
140 output_size.size() == 3,
141 "It is expected output_size equals to 3, but got size ",
142 output_size.size());
143
144 TORCH_CHECK(
145 input.numel() != 0 && input.dim() == 5,
146 "Non-empty 5D data tensor expected but got a tensor with sizes ",
147 input.sizes());
148
149 int64_t output_depth = output_size[0];
150 int64_t output_height = output_size[1];
151 int64_t output_width = output_size[2];
152
153 int64_t nbatch = input.size(0);
154 int64_t channels = input.size(1);
155 int64_t input_depth = input.size(2);
156 int64_t input_height = input.size(3);
157 int64_t input_width = input.size(4);
158 AT_ASSERT(input_width > 0 && output_width > 0);
159 if (input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
160 Tensor output = at::_empty_affine_quantized(
161 {nbatch, channels, output_depth, output_height, output_width},
162 input.options().memory_format(input.suggest_memory_format()),
163 input.q_scale(),
164 input.q_zero_point(),
165 std::nullopt);
166
167 AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_nearest3d", [&] {
168 auto* idata = static_cast<scalar_t*>(input.data_ptr());
169 auto* odata = static_cast<scalar_t*>(output.data_ptr());
170 upsample_nearest3d_out_frame_nhwc<scalar_t, nn_compute_source_index_fn>(
171 odata,
172 idata,
173 input_depth,
174 input_height,
175 input_width,
176 output_depth,
177 output_height,
178 output_width,
179 nbatch,
180 channels,
181 scales_d,
182 scales_h,
183 scales_w);
184 });
185 return output;
186 } else {
187 Tensor output = at::_empty_affine_quantized(
188 {nbatch, channels, output_depth, output_height, output_width},
189 input.options(),
190 input.q_scale(),
191 input.q_zero_point());
192
193 auto input_contig = input.contiguous();
194
195 AT_DISPATCH_QINT_TYPES(input_contig.scalar_type(), "upsample_nearest3d", [&] {
196 auto* idata = static_cast<scalar_t*>(input_contig.data_ptr());
197 auto* odata = static_cast<scalar_t*>(output.data_ptr());
198 upsample_nearest3d_out_frame<scalar_t, nn_compute_source_index_fn>(
199 odata,
200 idata,
201 input_depth,
202 input_height,
203 input_width,
204 output_depth,
205 output_height,
206 output_width,
207 nbatch,
208 channels,
209 scales_d,
210 scales_h,
211 scales_w);
212 });
213 return output;
214 }
215 }
216
// Public entry point registered in native_functions.yaml: quantized 3-D
// nearest-neighbor upsampling with the classic "nearest" index rounding.
Tensor upsample_nearest3d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_d,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  // Forward to the shared implementation with the non-exact source-index rule.
  return _upsample_nearest3d_quantized_cpu<
      nearest_neighbor_compute_source_index>(
      input, osize, scale_d, scale_h, scale_w);
}
226
// Public entry point registered in native_functions.yaml: quantized 3-D
// nearest-neighbor upsampling with the "nearest-exact" index rounding
// (matches OpenCV/PIL nearest semantics).
Tensor _upsample_nearest_exact3d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_d,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  // Forward to the shared implementation with the exact source-index rule.
  return _upsample_nearest3d_quantized_cpu<
      nearest_neighbor_exact_compute_source_index>(
      input, osize, scale_d, scale_h, scale_w);
}
236
237 } // namespace native
238 } // namespace at
239