• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
17 #define TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
18 
19 // See docs in ../ops/image_ops.cc.
20 
21 #define EIGEN_USE_THREADS
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
30 namespace generator {
31 
// Sampling strategies a ProjectiveGenerator can apply when the projected
// input coordinate does not land exactly on a pixel.
enum Interpolation {
  // Read the single pixel at the rounded coordinate.
  INTERPOLATION_NEAREST,
  // Blend the four pixels surrounding the coordinate.
  INTERPOLATION_BILINEAR,
};
33 
34 using Eigen::array;
35 using Eigen::DenseIndex;
36 
37 template <typename Device, typename T>
38 class ProjectiveGenerator {
39  private:
40   typename TTypes<T, 4>::ConstTensor input_;
41   typename TTypes<float>::ConstMatrix transforms_;
42   const Interpolation interpolation_;
43 
44  public:
45   static const int kNumParameters = 8;
46 
47   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T,4>::ConstTensor input,typename TTypes<float>::ConstMatrix transforms,const Interpolation interpolation)48   ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
49                       typename TTypes<float>::ConstMatrix transforms,
50                       const Interpolation interpolation)
51       : input_(input), transforms_(transforms), interpolation_(interpolation) {}
52 
53   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()54   operator()(const array<DenseIndex, 4>& coords) const {
55     const int64 output_y = coords[1];
56     const int64 output_x = coords[2];
57     const float* transform =
58         transforms_.dimension(0) == 1
59             ? transforms_.data()
60             : &transforms_.data()[transforms_.dimension(1) * coords[0]];
61     float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
62     if (projection == 0) {
63       // Return the fill value (0) for infinite coordinates,
64       // which are outside the input image
65       return T(0);
66     }
67     const float input_x =
68         (transform[0] * output_x + transform[1] * output_y + transform[2]) /
69         projection;
70     const float input_y =
71         (transform[3] * output_x + transform[4] * output_y + transform[5]) /
72         projection;
73 
74     const T fill_value = T(0);
75     switch (interpolation_) {
76       case INTERPOLATION_NEAREST:
77         // Switch the order of x and y again for indexing into the image.
78         return nearest_interpolation(coords[0], input_y, input_x, coords[3],
79                                      fill_value);
80       case INTERPOLATION_BILINEAR:
81         return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
82                                       fill_value);
83     }
84     // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
85     // or INTERPOLATION_BILINEAR.
86     return T(0);
87   }
88 
89  private:
90   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
nearest_interpolation(const DenseIndex batch,const float y,const float x,const DenseIndex channel,const T fill_value)91   nearest_interpolation(const DenseIndex batch, const float y, const float x,
92                         const DenseIndex channel, const T fill_value) const {
93     return read_with_fill_value(batch, DenseIndex(std::round(y)),
94                                 DenseIndex(std::round(x)), channel, fill_value);
95   }
96 
97   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
bilinear_interpolation(const DenseIndex batch,const float y,const float x,const DenseIndex channel,const T fill_value)98   bilinear_interpolation(const DenseIndex batch, const float y, const float x,
99                          const DenseIndex channel, const T fill_value) const {
100     const float y_floor = std::floor(y);
101     const float x_floor = std::floor(x);
102     const float y_ceil = y_floor + 1;
103     const float x_ceil = x_floor + 1;
104     // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
105     //               + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
106     const float value_yfloor =
107         (x_ceil - x) * static_cast<float>(read_with_fill_value(
108                            batch, DenseIndex(y_floor), DenseIndex(x_floor),
109                            channel, fill_value)) +
110         (x - x_floor) * static_cast<float>(read_with_fill_value(
111                             batch, DenseIndex(y_floor), DenseIndex(x_ceil),
112                             channel, fill_value));
113     // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
114     //              + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
115     const float value_yceil =
116         (x_ceil - x) * static_cast<float>(read_with_fill_value(
117                            batch, DenseIndex(y_ceil), DenseIndex(x_floor),
118                            channel, fill_value)) +
119         (x - x_floor) * static_cast<float>(read_with_fill_value(
120                             batch, DenseIndex(y_ceil), DenseIndex(x_ceil),
121                             channel, fill_value));
122     // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
123     //         + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
124     return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
125   }
126 
read_with_fill_value(const DenseIndex batch,const DenseIndex y,const DenseIndex x,const DenseIndex channel,const T fill_value)127   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
128       const DenseIndex batch, const DenseIndex y, const DenseIndex x,
129       const DenseIndex channel, const T fill_value) const {
130     // batch and channel must be correct, because they are passed unchanged from
131     // the input.
132     return (0 <= y && y < input_.dimension(1) && 0 <= x &&
133             x < input_.dimension(2))
134                ? input_(array<DenseIndex, 4>{batch, y, x, channel})
135                : fill_value;
136   }
137 };
138 
139 }  // end namespace generator
140 
141 // NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly
142 // instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing
143 // some Eigen device code.
144 namespace functor {
145 
146 using generator::Interpolation;
147 using generator::ProjectiveGenerator;
148 
149 template <typename Device, typename T>
150 struct FillProjectiveTransform {
151   typedef typename TTypes<T, 4>::Tensor OutputType;
152   typedef typename TTypes<T, 4>::ConstTensor InputType;
153   typedef typename TTypes<float, 2>::ConstTensor TransformsType;
154   const Interpolation interpolation_;
155 
FillProjectiveTransformFillProjectiveTransform156   FillProjectiveTransform(Interpolation interpolation)
157       : interpolation_(interpolation) {}
158 
159   EIGEN_ALWAYS_INLINE
operatorFillProjectiveTransform160   void operator()(const Device& device, OutputType* output,
161                   const InputType& images,
162                   const TransformsType& transform) const {
163     output->device(device) = output->generate(
164         ProjectiveGenerator<Device, T>(images, transform, interpolation_));
165   }
166 };
167 
168 }  // end namespace functor
169 
170 }  // end namespace tensorflow
171 
172 #endif  // TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
173