• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 #define EIGEN_USE_THREADS
18 
19 #include "tensorflow/core/kernels/resize_nearest_neighbor_op.h"
20 
21 #include <memory>
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/kernels/image_resizer_state.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace tensorflow {
33 
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 typedef Eigen::GpuDevice GPUDevice;
36 
37 template <typename Device, typename T>
38 class ResizeNearestNeighborOp : public OpKernel {
39  public:
ResizeNearestNeighborOp(OpKernelConstruction * context)40   explicit ResizeNearestNeighborOp(OpKernelConstruction* context)
41       : OpKernel(context) {
42     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
43     OP_REQUIRES_OK(
44         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
45   }
46 
Compute(OpKernelContext * context)47   void Compute(OpKernelContext* context) override {
48     const Tensor& input = context->input(0);
49     ImageResizerState st(align_corners_, half_pixel_centers_);
50     st.ValidateAndCreateOutput(context, input);
51 
52     if (!context->status().ok()) return;
53 
54     OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24),
55                 errors::InvalidArgument("nearest neighbor requires max height "
56                                         "& width of 2^24"));
57 
58     // Return if the output is empty.
59     if (st.output->NumElements() == 0) return;
60 
61     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
62     typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
63 
64     bool status;
65     if (half_pixel_centers_) {
66       if (align_corners_) {
67         status = functor::ResizeNearestNeighbor<Device, T,
68                                                 /*half_pixe_centers=*/true,
69                                                 /*align_corners=*/true>()(
70             context->eigen_device<Device>(), input_data, st.height_scale,
71             st.width_scale, output_data);
72       } else {
73         status = functor::ResizeNearestNeighbor<Device, T,
74                                                 /*half_pixe_centers=*/true,
75                                                 /*align_corners=*/false>()(
76             context->eigen_device<Device>(), input_data, st.height_scale,
77             st.width_scale, output_data);
78       }
79     } else {
80       if (align_corners_) {
81         status = functor::ResizeNearestNeighbor<Device, T,
82                                                 /*half_pixe_centers=*/false,
83                                                 /*align_corners=*/true>()(
84             context->eigen_device<Device>(), input_data, st.height_scale,
85             st.width_scale, output_data);
86       } else {
87         status = functor::ResizeNearestNeighbor<Device, T,
88                                                 /*half_pixe_centers=*/false,
89                                                 /*align_corners=*/false>()(
90             context->eigen_device<Device>(), input_data, st.height_scale,
91             st.width_scale, output_data);
92       }
93     }
94     if (!status) {
95       context->SetStatus(
96           errors::Internal("Failed launching ResizeNearestNeighbor"));
97     }
98   }
99 
100  private:
101   bool align_corners_;
102   bool half_pixel_centers_;
103 };
104 
// Compile-time map from the half_pixel_centers flag to a scaler type.
// The primary template is intentionally empty; only the <true> and <false>
// specializations provide the nested `Scaler` type.
template <bool use_half_pixel_centers>
struct BoolToScaler {};
108 
// Scaler used by the nearest-neighbor code when half_pixel_centers is set.
struct HalfPixelScalerForNN {
  // Returns (x + 0.5) * scale.
  //
  // The nearest-neighbor code that calls this applies std::floor() to the
  // result right away, so the 0.5 is deliberately not subtracted again the
  // way HalfPixelScale does; the floor already yields the correct index.
  float operator()(const int x, const float scale) const {
    const float pixel_center = static_cast<float>(x) + 0.5f;
    return pixel_center * scale;
  }
};
118 
119 template <>
120 struct BoolToScaler<true> {
121   typedef HalfPixelScalerForNN Scaler;
122 };
123 
124 template <>
125 struct BoolToScaler<false> {
126   typedef LegacyScaler Scaler;
127 };
128 
129 // Partial specialization of ResizeNearestNeighbor functor for a CPUDevice.
130 namespace functor {
131 template <typename T, bool half_pixel_centers, bool align_corners>
132 struct ResizeNearestNeighbor<CPUDevice, T, half_pixel_centers, align_corners> {
operator ()tensorflow::functor::ResizeNearestNeighbor133   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
134                   const float height_scale, const float width_scale,
135                   typename TTypes<T, 4>::Tensor output) {
136     typename BoolToScaler<half_pixel_centers>::Scaler scaler;
137     const Eigen::Index batch_size = input.dimension(0);
138     const Eigen::Index in_height = input.dimension(1);
139     const Eigen::Index in_width = input.dimension(2);
140     const Eigen::Index channels = input.dimension(3);
141 
142     const Eigen::Index out_height = output.dimension(1);
143     const Eigen::Index out_width = output.dimension(2);
144 
145     for (Eigen::Index b = 0; b < batch_size; ++b) {
146       for (Eigen::Index y = 0; y < out_height; ++y) {
147         Eigen::Index in_y = std::min(
148             (align_corners)
149                 ? static_cast<Eigen::Index>(roundf(scaler(y, height_scale)))
150                 : static_cast<Eigen::Index>(floorf(scaler(y, height_scale))),
151             in_height - 1);
152         if (half_pixel_centers) {
153           in_y = std::max(static_cast<Eigen::Index>(0), in_y);
154         }
155         for (Eigen::Index x = 0; x < out_width; ++x) {
156           Eigen::Index in_x = std::min(
157               (align_corners)
158                   ? static_cast<Eigen::Index>(roundf(scaler(x, width_scale)))
159                   : static_cast<Eigen::Index>(floorf(scaler(x, width_scale))),
160               in_width - 1);
161           if (half_pixel_centers) {
162             in_x = std::max(static_cast<Eigen::Index>(0), in_x);
163           }
164           std::copy_n(&input(b, in_y, in_x, 0), channels, &output(b, y, x, 0));
165         }
166       }
167     }
168     return true;
169   }
170 };
171 }  // namespace functor
172 
173 template <typename Device, typename T>
174 class ResizeNearestNeighborOpGrad : public OpKernel {
175  public:
ResizeNearestNeighborOpGrad(OpKernelConstruction * context)176   explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context)
177       : OpKernel(context) {
178     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
179     OP_REQUIRES_OK(
180         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
181   }
182 
Compute(OpKernelContext * context)183   void Compute(OpKernelContext* context) override {
184     // Grab and validate the input:
185     const Tensor& input = context->input(0);
186     OP_REQUIRES(context, input.dims() == 4,
187                 errors::InvalidArgument("input must be 4-dimensional",
188                                         input.shape().DebugString()));
189 
190     // Grab and validate the output shape:
191     const Tensor& shape_t = context->input(1);
192     OP_REQUIRES(context, shape_t.dims() == 1,
193                 errors::InvalidArgument("shape_t must be 1-dimensional",
194                                         shape_t.shape().DebugString()));
195     OP_REQUIRES(context, shape_t.NumElements() == 2,
196                 errors::InvalidArgument("shape_t must have two elements",
197                                         shape_t.shape().DebugString()));
198 
199     auto sizes = shape_t.vec<int32>();
200     OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
201                 errors::InvalidArgument("shape_t's elements must be positive"));
202 
203     const int64 batch_size = input.dim_size(0);
204     const int64 in_height = input.dim_size(1);
205     const int64 in_width = input.dim_size(2);
206     const int64 channels = input.dim_size(3);
207 
208     const int64 out_height = sizes(0);
209     const int64 out_width = sizes(1);
210 
211     Tensor* output = nullptr;
212     OP_REQUIRES_OK(
213         context,
214         context->allocate_output(
215             0, TensorShape({batch_size, out_height, out_width, channels}),
216             &output));
217 
218     // Return if the output is empty.
219     if (output->NumElements() == 0) return;
220 
221     typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
222     typename TTypes<T, 4>::Tensor output_data(output->tensor<T, 4>());
223 
224     const float height_scale =
225         CalculateResizeScale(out_height, in_height, align_corners_);
226     const float width_scale =
227         CalculateResizeScale(out_width, in_width, align_corners_);
228 
229     bool status;
230     if (half_pixel_centers_) {
231       if (align_corners_) {
232         status = functor::ResizeNearestNeighborGrad<Device, T,
233                                                     /*half_pixel_centers=*/true,
234                                                     /*align_corners=*/true>()(
235             context->eigen_device<Device>(), input_data, height_scale,
236             width_scale, output_data);
237       } else {
238         status = functor::ResizeNearestNeighborGrad<Device, T,
239                                                     /*half_pixel_centers=*/true,
240                                                     /*align_corners=*/false>()(
241             context->eigen_device<Device>(), input_data, height_scale,
242             width_scale, output_data);
243       }
244     } else {
245       if (align_corners_) {
246         status =
247             functor::ResizeNearestNeighborGrad<Device, T,
248                                                /*half_pixel_centers=*/false,
249                                                /*align_corners=*/true>()(
250                 context->eigen_device<Device>(), input_data, height_scale,
251                 width_scale, output_data);
252       } else {
253         status =
254             functor::ResizeNearestNeighborGrad<Device, T,
255                                                /*half_pixel_centers=*/false,
256                                                /*align_corners=*/false>()(
257                 context->eigen_device<Device>(), input_data, height_scale,
258                 width_scale, output_data);
259       }
260     }
261     if (!status) {
262       context->SetStatus(
263           errors::Internal("Failed launching ResizeNearestNeighborGrad"));
264     }
265   }
266 
267  private:
268   bool align_corners_;
269   bool half_pixel_centers_;
270 };
271 
272 // Partial specialization of ResizeNearestNeighborGrad functor for a CPUDevice.
273 namespace functor {
274 template <typename T, bool half_pixel_centers, bool align_corners>
275 struct ResizeNearestNeighborGrad<CPUDevice, T, half_pixel_centers,
276                                  align_corners> {
operator ()tensorflow::functor::ResizeNearestNeighborGrad277   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
278                   const float height_scale, const float width_scale,
279                   typename TTypes<T, 4>::Tensor output) {
280     typename BoolToScaler<half_pixel_centers>::Scaler scaler;
281     const Eigen::Index batch_size = input.dimension(0);
282     const Eigen::Index in_height = input.dimension(1);
283     const Eigen::Index in_width = input.dimension(2);
284     const Eigen::Index channels = input.dimension(3);
285 
286     const Eigen::Index out_height = output.dimension(1);
287     const Eigen::Index out_width = output.dimension(2);
288 
289     output.setZero();
290 
291     for (Eigen::Index y = 0; y < in_height; ++y) {
292       const Eigen::Index out_y = std::min(
293           (align_corners)
294               ? static_cast<Eigen::Index>(roundf(scaler(y, height_scale)))
295               : static_cast<Eigen::Index>(floorf(scaler(y, height_scale))),
296           out_height - 1);
297       for (Eigen::Index x = 0; x < in_width; ++x) {
298         const Eigen::Index out_x = std::min(
299             (align_corners)
300                 ? static_cast<Eigen::Index>(roundf(scaler(x, width_scale)))
301                 : static_cast<Eigen::Index>(floorf(scaler(x, width_scale))),
302             out_width - 1);
303         for (Eigen::Index b = 0; b < batch_size; ++b) {
304           for (Eigen::Index c = 0; c < channels; ++c) {
305             output(b, out_y, out_x, c) += input(b, y, x, c);
306           }
307         }
308       }
309     }
310     return true;
311   }
312 };
313 }  // namespace functor
314 
315 #define REGISTER_KERNEL(T)                                        \
316   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
317                               .Device(DEVICE_CPU)                 \
318                               .TypeConstraint<T>("T")             \
319                               .HostMemory("size"),                \
320                           ResizeNearestNeighborOp<CPUDevice, T>); \
321   REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
322                               .Device(DEVICE_CPU)                 \
323                               .TypeConstraint<T>("T")             \
324                               .HostMemory("size"),                \
325                           ResizeNearestNeighborOpGrad<CPUDevice, T>);
326 
327 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
328 
329 #undef REGISTER_KERNEL
330 
#if GOOGLE_CUDA

// Registers the GPU forward and gradient kernels for one element type T.
// Both registrations keep the "size" input in host memory.
#define REGISTER_GPU_RESIZE_NN_KERNELS(T)                         \
  REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor")           \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .HostMemory("size"),                \
                          ResizeNearestNeighborOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad")       \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<T>("T")             \
                              .HostMemory("size"),                \
                          ResizeNearestNeighborOpGrad<GPUDevice, T>);

// Instantiate the registrations for every type TF_CALL_GPU_NUMBER_TYPES
// enumerates.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_RESIZE_NN_KERNELS);

#undef REGISTER_GPU_RESIZE_NN_KERNELS

#endif  // GOOGLE_CUDA
350 
351 }  // namespace tensorflow
352