• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 #define EIGEN_USE_THREADS
18 
#include "tensorflow/core/kernels/image/resize_nearest_neighbor_op.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/image_resizer_state.h"
32 
33 namespace tensorflow {
34 
35 typedef Eigen::ThreadPoolDevice CPUDevice;
36 typedef Eigen::GpuDevice GPUDevice;
37 
38 template <typename Device, typename T>
39 class ResizeNearestNeighborOp : public OpKernel {
40  public:
ResizeNearestNeighborOp(OpKernelConstruction * context)41   explicit ResizeNearestNeighborOp(OpKernelConstruction* context)
42       : OpKernel(context) {
43     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
44     OP_REQUIRES_OK(
45         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
46   }
47 
Compute(OpKernelContext * context)48   void Compute(OpKernelContext* context) override {
49     const Tensor& input = context->input(0);
50     ImageResizerState st(align_corners_, half_pixel_centers_);
51     st.ValidateAndCreateOutput(context, input);
52 
53     if (!context->status().ok()) return;
54 
55     OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24),
56                 errors::InvalidArgument("nearest neighbor requires max height "
57                                         "& width of 2^24"));
58 
59     // Return if the output is empty.
60     if (st.output->NumElements() == 0) return;
61 
62     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
63     typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
64 
65     bool status;
66     if (half_pixel_centers_) {
67       if (align_corners_) {
68         status = functor::ResizeNearestNeighbor<Device, T,
69                                                 /*half_pixe_centers=*/true,
70                                                 /*align_corners=*/true>()(
71             context->eigen_device<Device>(), input_data, st.height_scale,
72             st.width_scale, output_data);
73       } else {
74         status = functor::ResizeNearestNeighbor<Device, T,
75                                                 /*half_pixe_centers=*/true,
76                                                 /*align_corners=*/false>()(
77             context->eigen_device<Device>(), input_data, st.height_scale,
78             st.width_scale, output_data);
79       }
80     } else {
81       if (align_corners_) {
82         status = functor::ResizeNearestNeighbor<Device, T,
83                                                 /*half_pixe_centers=*/false,
84                                                 /*align_corners=*/true>()(
85             context->eigen_device<Device>(), input_data, st.height_scale,
86             st.width_scale, output_data);
87       } else {
88         status = functor::ResizeNearestNeighbor<Device, T,
89                                                 /*half_pixe_centers=*/false,
90                                                 /*align_corners=*/false>()(
91             context->eigen_device<Device>(), input_data, st.height_scale,
92             st.width_scale, output_data);
93       }
94     }
95     if (!status) {
96       context->SetStatus(
97           errors::Internal("Failed launching ResizeNearestNeighbor"));
98     }
99   }
100 
101  private:
102   bool align_corners_;
103   bool half_pixel_centers_;
104 };
105 
// Compile-time mapping from the `half_pixel_centers` flag to the functor type
// that projects a destination index onto the source axis. Only the explicit
// `true`/`false` specializations carry a `Scaler`; the primary is empty.
template <bool kHalfPixelCenters>
struct BoolToScaler {};
109 
// Half-pixel-centered source coordinate for nearest-neighbor sampling.
// The general half-pixel mapping also subtracts 0.5 afterwards. Every use of
// this scaler in the nearest-neighbor code is immediately followed by a
// floor/round, so the subtraction is intentionally omitted: flooring the
// un-shifted value already selects the correct source sample.
struct HalfPixelScalerForNN {
  inline float operator()(const int x, const float scale) const {
    const float pixel_center = static_cast<float>(x) + 0.5f;
    return pixel_center * scale;
  }
};
119 
120 template <>
121 struct BoolToScaler<true> {
122   typedef HalfPixelScalerForNN Scaler;
123 };
124 
125 template <>
126 struct BoolToScaler<false> {
127   typedef LegacyScaler Scaler;
128 };
129 
130 template <bool half_pixel_centers, bool align_corners>
compute_indices(const Eigen::Index out_size,const Eigen::Index in_size,const float scale,Eigen::Index * indices)131 void compute_indices(const Eigen::Index out_size, const Eigen::Index in_size,
132                      const float scale, Eigen::Index* indices) {
133   typename BoolToScaler<half_pixel_centers>::Scaler scaler;
134   for (Eigen::Index i = 0; i < out_size; ++i) {
135     Eigen::Index x = std::min(
136         (align_corners) ? static_cast<Eigen::Index>(roundf(scaler(i, scale)))
137                         : static_cast<Eigen::Index>(floorf(scaler(i, scale))),
138         in_size - 1);
139     if (half_pixel_centers) {
140       x = std::max(static_cast<Eigen::Index>(0), x);
141     }
142     indices[i] = x;
143   }
144 }
145 
146 namespace generator {
147 template <typename T, bool half_pixel_centers, bool align_corners>
148 class ResizeNearestNeighborGenerator {
149  public:
ResizeNearestNeighborGenerator(typename TTypes<T,4>::ConstTensor input,const Eigen::Index output_height,const Eigen::Index output_width,const float height_scale,const float width_scale)150   EIGEN_ALWAYS_INLINE ResizeNearestNeighborGenerator(
151       typename TTypes<T, 4>::ConstTensor input,
152       const Eigen::Index output_height, const Eigen::Index output_width,
153       const float height_scale, const float width_scale)
154       : input_(input), ys_(output_height), xs_(output_width) {
155     const Eigen::Index input_height = input.dimension(1);
156     const Eigen::Index input_width = input.dimension(2);
157     compute_indices<half_pixel_centers, align_corners>(
158         output_height, input_height, height_scale, ys_.data());
159     compute_indices<half_pixel_centers, align_corners>(
160         output_width, input_width, width_scale, xs_.data());
161   }
162 
163   EIGEN_ALWAYS_INLINE T
operator ()(const Eigen::array<Eigen::Index,4> & coords) const164   operator()(const Eigen::array<Eigen::Index, 4>& coords) const {
165     const Eigen::Index b = coords[0];
166     const Eigen::Index y = coords[1];
167     const Eigen::Index x = coords[2];
168     const Eigen::Index c = coords[3];
169 
170     const Eigen::Index in_y = ys_[y];
171     const Eigen::Index in_x = xs_[x];
172     return input_(b, in_y, in_x, c);
173   }
174 
175  private:
176   typename TTypes<T, 4>::ConstTensor input_;
177   std::vector<Eigen::Index> ys_, xs_;
178 };
179 }  // namespace generator
180 
181 // Partial specialization of ResizeNearestNeighbor functor for a CPUDevice.
182 namespace functor {
183 template <typename T, bool half_pixel_centers, bool align_corners>
184 struct ResizeNearestNeighbor<CPUDevice, T, half_pixel_centers, align_corners> {
operator ()tensorflow::functor::ResizeNearestNeighbor185   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
186                   const float height_scale, const float width_scale,
187                   typename TTypes<T, 4>::Tensor output) {
188     const Eigen::Index output_height = output.dimension(1);
189     const Eigen::Index output_width = output.dimension(2);
190     generator::ResizeNearestNeighborGenerator<T, half_pixel_centers,
191                                               align_corners>
192         generator(input, output_height, output_width, height_scale,
193                   width_scale);
194     output.device(d) = output.generate(std::move(generator));
195     return true;
196   }
197 };
198 }  // namespace functor
199 
200 template <typename Device, typename T>
201 class ResizeNearestNeighborOpGrad : public OpKernel {
202  public:
ResizeNearestNeighborOpGrad(OpKernelConstruction * context)203   explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context)
204       : OpKernel(context) {
205     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
206     OP_REQUIRES_OK(
207         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
208   }
209 
Compute(OpKernelContext * context)210   void Compute(OpKernelContext* context) override {
211     // Grab and validate the input:
212     const Tensor& input = context->input(0);
213     OP_REQUIRES(context, input.dims() == 4,
214                 errors::InvalidArgument("input must be 4-dimensional",
215                                         input.shape().DebugString()));
216 
217     // Grab and validate the output shape:
218     const Tensor& shape_t = context->input(1);
219     OP_REQUIRES(context, shape_t.dims() == 1,
220                 errors::InvalidArgument("shape_t must be 1-dimensional",
221                                         shape_t.shape().DebugString()));
222     OP_REQUIRES(context, shape_t.NumElements() == 2,
223                 errors::InvalidArgument("shape_t must have two elements",
224                                         shape_t.shape().DebugString()));
225 
226     auto sizes = shape_t.vec<int32>();
227     OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
228                 errors::InvalidArgument("shape_t's elements must be positive"));
229 
230     const int64 batch_size = input.dim_size(0);
231     const int64 in_height = input.dim_size(1);
232     const int64 in_width = input.dim_size(2);
233     const int64 channels = input.dim_size(3);
234 
235     const int64 out_height = sizes(0);
236     const int64 out_width = sizes(1);
237 
238     Tensor* output = nullptr;
239     OP_REQUIRES_OK(
240         context,
241         context->allocate_output(
242             0, TensorShape({batch_size, out_height, out_width, channels}),
243             &output));
244 
245     // Return if the output is empty.
246     if (output->NumElements() == 0) return;
247 
248     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
249     typename TTypes<T, 4>::Tensor output_data(output->tensor<T, 4>());
250 
251     const float height_scale =
252         CalculateResizeScale(out_height, in_height, align_corners_);
253     const float width_scale =
254         CalculateResizeScale(out_width, in_width, align_corners_);
255 
256     bool status;
257     if (half_pixel_centers_) {
258       if (align_corners_) {
259         status = functor::ResizeNearestNeighborGrad<Device, T,
260                                                     /*half_pixel_centers=*/true,
261                                                     /*align_corners=*/true>()(
262             context->eigen_device<Device>(), input_data, height_scale,
263             width_scale, output_data);
264       } else {
265         status = functor::ResizeNearestNeighborGrad<Device, T,
266                                                     /*half_pixel_centers=*/true,
267                                                     /*align_corners=*/false>()(
268             context->eigen_device<Device>(), input_data, height_scale,
269             width_scale, output_data);
270       }
271     } else {
272       if (align_corners_) {
273         status =
274             functor::ResizeNearestNeighborGrad<Device, T,
275                                                /*half_pixel_centers=*/false,
276                                                /*align_corners=*/true>()(
277                 context->eigen_device<Device>(), input_data, height_scale,
278                 width_scale, output_data);
279       } else {
280         status =
281             functor::ResizeNearestNeighborGrad<Device, T,
282                                                /*half_pixel_centers=*/false,
283                                                /*align_corners=*/false>()(
284                 context->eigen_device<Device>(), input_data, height_scale,
285                 width_scale, output_data);
286       }
287     }
288     if (!status) {
289       context->SetStatus(
290           errors::Internal("Failed launching ResizeNearestNeighborGrad"));
291     }
292   }
293 
294  private:
295   bool align_corners_;
296   bool half_pixel_centers_;
297 };
298 
299 // Partial specialization of ResizeNearestNeighborGrad functor for a CPUDevice.
300 namespace functor {
301 template <typename T, bool half_pixel_centers, bool align_corners>
302 struct ResizeNearestNeighborGrad<CPUDevice, T, half_pixel_centers,
303                                  align_corners> {
operator ()tensorflow::functor::ResizeNearestNeighborGrad304   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
305                   const float height_scale, const float width_scale,
306                   typename TTypes<T, 4>::Tensor output) {
307     typename BoolToScaler<half_pixel_centers>::Scaler scaler;
308     const Eigen::Index batch_size = input.dimension(0);
309     const Eigen::Index in_height = input.dimension(1);
310     const Eigen::Index in_width = input.dimension(2);
311     const Eigen::Index channels = input.dimension(3);
312 
313     const Eigen::Index out_height = output.dimension(1);
314     const Eigen::Index out_width = output.dimension(2);
315 
316     output.setZero();
317 
318     for (Eigen::Index y = 0; y < in_height; ++y) {
319       const Eigen::Index out_y = std::min(
320           (align_corners)
321               ? static_cast<Eigen::Index>(roundf(scaler(y, height_scale)))
322               : static_cast<Eigen::Index>(floorf(scaler(y, height_scale))),
323           out_height - 1);
324       for (Eigen::Index x = 0; x < in_width; ++x) {
325         const Eigen::Index out_x = std::min(
326             (align_corners)
327                 ? static_cast<Eigen::Index>(roundf(scaler(x, width_scale)))
328                 : static_cast<Eigen::Index>(floorf(scaler(x, width_scale))),
329             out_width - 1);
330         for (Eigen::Index b = 0; b < batch_size; ++b) {
331           for (Eigen::Index c = 0; c < channels; ++c) {
332             output(b, out_y, out_x, c) += input(b, y, x, c);
333           }
334         }
335       }
336     }
337     return true;
338   }
339 };
340 }  // namespace functor
341 
342 #define REGISTER_KERNEL(T)                                        \
343   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
344                               .Device(DEVICE_CPU)                 \
345                               .TypeConstraint<T>("T")             \
346                               .HostMemory("size"),                \
347                           ResizeNearestNeighborOp<CPUDevice, T>); \
348   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
349                               .Device(DEVICE_CPU)                 \
350                               .TypeConstraint<T>("T")             \
351                               .HostMemory("size"),                \
352                           ResizeNearestNeighborOpGrad<CPUDevice, T>);
353 
354 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
355 
356 #undef REGISTER_KERNEL
357 
358 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
359 
360 #define REGISTER_KERNEL(T)                                        \
361   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
362                               .Device(DEVICE_GPU)                 \
363                               .TypeConstraint<T>("T")             \
364                               .HostMemory("size"),                \
365                           ResizeNearestNeighborOp<GPUDevice, T>); \
366   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
367                               .Device(DEVICE_GPU)                 \
368                               .TypeConstraint<T>("T")             \
369                               .HostMemory("size"),                \
370                           ResizeNearestNeighborOpGrad<GPUDevice, T>);
371 
372 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
373 
374 #undef REGISTER_KERNEL
375 
376 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
377 
378 }  // namespace tensorflow
379