• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/UpSample.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_empty_affine_quantized.h>
11 #include <ATen/ops/_upsample_nearest_exact3d_native.h>
12 #include <ATen/ops/upsample_nearest3d_native.h>
13 #endif
14 
15 #include <c10/util/irange.h>
16 
17 #include <cstring>
18 
19 
20 namespace at {
21 namespace native {
22 
23 // Define a typedef to dispatch to nearest_idx or nearest_exact_idx
24 typedef int64_t (*nn_compute_source_index_fn_t)(const float, int64_t, int64_t);
25 
26 // at::native functions for the native_functions.yaml
27 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest3d_out_frame(scalar_t * odata,scalar_t * idata,int64_t input_depth,int64_t input_height,int64_t input_width,int64_t output_depth,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)28 static void upsample_nearest3d_out_frame(
29     scalar_t* odata,
30     scalar_t* idata,
31     int64_t input_depth,
32     int64_t input_height,
33     int64_t input_width,
34     int64_t output_depth,
35     int64_t output_height,
36     int64_t output_width,
37     int64_t nbatch,
38     int64_t channels,
39     std::optional<double> scales_d,
40     std::optional<double> scales_h,
41     std::optional<double> scales_w) {
42   float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
43   float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
44   float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
45 
46   channels = channels * nbatch;
47   if (channels == 0 || output_depth == 0 || output_height == 0 || output_width == 0) {
48     return;
49   }
50   auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata);
51   auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata);
52 
53   // special case: just copy
54   if (input_depth == output_depth && input_height == output_height && input_width == output_width) {
55     std::memcpy(o_p, i_p, channels * input_depth * input_height * input_width * sizeof(typename scalar_t::underlying));
56     return;
57   }
58 
59   for (const auto d2 : c10::irange(output_depth)) {
60     const int64_t d1 =
61           nn_compute_source_index_fn(depth_scale, d2, input_depth);
62 
63     for (const auto h2 : c10::irange(output_height)) {
64       const int64_t h1 =
65           nn_compute_source_index_fn(height_scale, h2, input_height);
66 
67       for (const auto w2 : c10::irange(output_width)) {
68         const int64_t w1 =
69             nn_compute_source_index_fn(width_scale, w2, input_width);
70 
71         const auto* pos1 = &i_p[d1 * input_height * input_width + h1 * input_width + w1];
72         auto* pos2 = &o_p[d2 * output_height * output_width + h2 * output_width + w2];
73 
74         for (C10_UNUSED const auto c : c10::irange(channels)) {
75           pos2[0] = pos1[0];
76           pos1 += input_depth * input_height * input_width;
77           pos2 += output_depth * output_height * output_width;
78         }
79       }
80     }
81   }
82 }
83 
84 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest3d_out_frame_nhwc(scalar_t * odata,scalar_t * idata,int64_t input_depth,int64_t input_height,int64_t input_width,int64_t output_depth,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)85 static void upsample_nearest3d_out_frame_nhwc(
86     scalar_t* odata,
87     scalar_t* idata,
88     int64_t input_depth,
89     int64_t input_height,
90     int64_t input_width,
91     int64_t output_depth,
92     int64_t output_height,
93     int64_t output_width,
94     int64_t nbatch,
95     int64_t channels,
96     std::optional<double> scales_d,
97     std::optional<double> scales_h,
98     std::optional<double> scales_w) {
99   float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
100   float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
101   float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
102 
103   for (const auto b : c10::irange(nbatch)) {
104     auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata + b * input_depth * input_height * input_width * channels);
105     auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata + b * output_depth * output_height * output_width * channels);
106     // special case: just copy
107     if (input_depth == output_depth && input_height == output_height && input_width == output_width) {
108       std::memcpy(o_p, i_p, channels * input_depth * input_height * input_width * sizeof(typename scalar_t::underlying));
109       return;
110     }
111 
112     for (const auto d2 : c10::irange(output_depth)) {
113       const int64_t d1 =
114           nn_compute_source_index_fn(depth_scale, d2, input_depth);
115       for (const auto h2 : c10::irange(output_height)) {
116         const int64_t h1 =
117             nn_compute_source_index_fn(height_scale, h2, input_height);
118 
119         for (const auto w2 : c10::irange(output_width)) {
120           const int64_t w1 =
121               nn_compute_source_index_fn(width_scale, w2, input_width);
122 
123           const auto* pos1 = &i_p[(d1 * input_height * input_width + h1 * input_width + w1)*channels];
124           auto* pos2 = &o_p[(d2 * output_height * output_width + h2 * output_width + w2)*channels];
125           std::memcpy(pos2, pos1, channels * sizeof(typename scalar_t::underlying));
126         }
127       }
128     }
129   }
130 }
131 
132 template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
_upsample_nearest3d_quantized_cpu(const Tensor & input,IntArrayRef output_size,std::optional<double> scales_d,std::optional<double> scales_h,std::optional<double> scales_w)133 Tensor _upsample_nearest3d_quantized_cpu(
134     const Tensor& input,
135     IntArrayRef output_size,
136     std::optional<double> scales_d,
137     std::optional<double> scales_h,
138     std::optional<double> scales_w) {
139   TORCH_CHECK(
140       output_size.size() == 3,
141       "It is expected output_size equals to 3, but got size ",
142       output_size.size());
143 
144   TORCH_CHECK(
145       input.numel() != 0 && input.dim() == 5,
146       "Non-empty 5D data tensor expected but got a tensor with sizes ",
147       input.sizes());
148 
149   int64_t output_depth = output_size[0];
150   int64_t output_height = output_size[1];
151   int64_t output_width = output_size[2];
152 
153   int64_t nbatch = input.size(0);
154   int64_t channels = input.size(1);
155   int64_t input_depth = input.size(2);
156   int64_t input_height = input.size(3);
157   int64_t input_width = input.size(4);
158   AT_ASSERT(input_width > 0 && output_width > 0);
159   if (input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
160     Tensor output = at::_empty_affine_quantized(
161         {nbatch, channels, output_depth, output_height, output_width},
162         input.options().memory_format(input.suggest_memory_format()),
163         input.q_scale(),
164         input.q_zero_point(),
165         std::nullopt);
166 
167     AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_nearest3d", [&] {
168       auto* idata = static_cast<scalar_t*>(input.data_ptr());
169       auto* odata = static_cast<scalar_t*>(output.data_ptr());
170       upsample_nearest3d_out_frame_nhwc<scalar_t, nn_compute_source_index_fn>(
171           odata,
172           idata,
173           input_depth,
174           input_height,
175           input_width,
176           output_depth,
177           output_height,
178           output_width,
179           nbatch,
180           channels,
181           scales_d,
182           scales_h,
183           scales_w);
184     });
185     return output;
186   } else {
187     Tensor output = at::_empty_affine_quantized(
188         {nbatch, channels, output_depth, output_height, output_width},
189         input.options(),
190         input.q_scale(),
191         input.q_zero_point());
192 
193     auto input_contig = input.contiguous();
194 
195     AT_DISPATCH_QINT_TYPES(input_contig.scalar_type(), "upsample_nearest3d", [&] {
196       auto* idata = static_cast<scalar_t*>(input_contig.data_ptr());
197       auto* odata = static_cast<scalar_t*>(output.data_ptr());
198       upsample_nearest3d_out_frame<scalar_t, nn_compute_source_index_fn>(
199           odata,
200           idata,
201           input_depth,
202           input_height,
203           input_width,
204           output_depth,
205           output_height,
206           output_width,
207           nbatch,
208           channels,
209           scales_d,
210           scales_h,
211           scales_w);
212     });
213     return output;
214   }
215 }
216 
upsample_nearest3d_quantized_cpu(const Tensor & input,IntArrayRef osize,std::optional<double> scale_d,std::optional<double> scale_h,std::optional<double> scale_w)217 Tensor upsample_nearest3d_quantized_cpu(
218     const Tensor& input,
219     IntArrayRef osize,
220     std::optional<double> scale_d,
221     std::optional<double> scale_h,
222     std::optional<double> scale_w) {
223   return _upsample_nearest3d_quantized_cpu<nearest_neighbor_compute_source_index>(
224       input, osize, scale_d, scale_h, scale_w);
225 }
226 
_upsample_nearest_exact3d_quantized_cpu(const Tensor & input,IntArrayRef osize,std::optional<double> scale_d,std::optional<double> scale_h,std::optional<double> scale_w)227 Tensor _upsample_nearest_exact3d_quantized_cpu(
228     const Tensor& input,
229     IntArrayRef osize,
230     std::optional<double> scale_d,
231     std::optional<double> scale_h,
232     std::optional<double> scale_w) {
233   return _upsample_nearest3d_quantized_cpu<nearest_neighbor_exact_compute_source_index>(
234       input, osize, scale_d, scale_h, scale_w);
235 }
236 
237 } // namespace native
238 } // namespace at
239