/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/image/image_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace functor {

// Explicit instantiation of the CPU functor.
typedef Eigen::ThreadPoolDevice CPUDevice;

template struct FillProjectiveTransform<CPUDevice, uint8>;
template struct FillProjectiveTransform<CPUDevice, int32>;
template struct FillProjectiveTransform<CPUDevice, int64_t>;
template struct FillProjectiveTransform<CPUDevice, Eigen::half>;
template struct FillProjectiveTransform<CPUDevice, float>;
template struct FillProjectiveTransform<CPUDevice, double>;

}  // end namespace functor

typedef Eigen::ThreadPoolDevice CPUDevice;

using functor::FillProjectiveTransform;
using generator::Interpolation;
using generator::Mode;

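// Shared implementation for the ImageProjectiveTransform kernels: validates
// the image and transform inputs, resolves the output size and fill value,
// allocates the output tensor, and applies the FillProjectiveTransform
// functor on the given device.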
template <typename Device, typename T>
void DoImageProjectiveTransformOp(OpKernelContext* ctx,
                                  const Interpolation& interpolation,
                                  const Mode& fill_mode) {
  const Tensor& images_t = ctx->input(0);
  const Tensor& transform_t = ctx->input(1);
  OP_REQUIRES(ctx, images_t.shape().dims() == 4,
              errors::InvalidArgument("Input images must have rank 4"));
  OP_REQUIRES(ctx,
              (TensorShapeUtils::IsMatrix(transform_t.shape()) &&
               (transform_t.dim_size(0) == images_t.dim_size(0) ||
                transform_t.dim_size(0) == 1) &&
               transform_t.dim_size(1) == 8),
              errors::InvalidArgument(
                  "Input transform should be num_images x 8 or 1 x 8"));

  int32_t out_height, out_width;
  // Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
  if (ctx->num_inputs() >= 3) {
    const Tensor& shape_t = ctx->input(2);
    OP_REQUIRES(ctx, shape_t.dims() == 1,
                errors::InvalidArgument("output shape must be 1-dimensional",
                                        shape_t.shape().DebugString()));
    OP_REQUIRES(ctx, shape_t.NumElements() == 2,
                errors::InvalidArgument("output shape must have two elements",
                                        shape_t.shape().DebugString()));
    auto shape_vec = shape_t.vec<int32>();
    out_height = shape_vec(0);
    out_width = shape_vec(1);
    OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
                errors::InvalidArgument("output dimensions must be positive"));
  } else {
    // Shape is N (batch size), H (height), W (width), C (channels).
    out_height = images_t.shape().dim_size(1);
    out_width = images_t.shape().dim_size(2);
  }

  T fill_value(0);
  // Kernel is shared by "ImageProjectiveTransformV2" with 3 args.
  if (ctx->num_inputs() >= 4) {
    const Tensor& fill_value_t = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(fill_value_t.shape()),
                errors::InvalidArgument("fill_value must be a scalar",
                                        fill_value_t.shape().DebugString()));
    fill_value = static_cast<T>(*(fill_value_t.scalar<float>().data()));
  }

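  // Allocate the output with the requested spatial size and the input's batch
  // and channel dimensions.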
  Tensor* output_t;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(0,
                                TensorShape({images_t.dim_size(0), out_height,
                                             out_width, images_t.dim_size(3)}),
                                &output_t));
  auto output = output_t->tensor<T, 4>();
  auto images = images_t.tensor<T, 4>();
  auto transform = transform_t.matrix<float>();

  (FillProjectiveTransform<Device, T>(interpolation))(
      ctx->eigen_device<Device>(), &output, images, transform, fill_mode,
      fill_value);
}

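// OpKernel wrapper for ImageProjectiveTransformV2. The interpolation and
// fill-mode attributes are parsed once at construction time; each Compute call
// delegates to the shared implementation above.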
template <typename Device, typename T>
class ImageProjectiveTransformV2 : public OpKernel {
 private:
  Interpolation interpolation_;
  Mode fill_mode_;

 public:
  explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    string interpolation_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
    if (interpolation_str == "NEAREST") {
      interpolation_ = Interpolation::NEAREST;
    } else if (interpolation_str == "BILINEAR") {
      interpolation_ = Interpolation::BILINEAR;
    } else {
      LOG(ERROR) << "Invalid interpolation " << interpolation_str
                 << ". Supported types: NEAREST, BILINEAR";
    }
    string mode_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_mode", &mode_str));
    if (mode_str == "REFLECT") {
      fill_mode_ = Mode::FILL_REFLECT;
    } else if (mode_str == "WRAP") {
      fill_mode_ = Mode::FILL_WRAP;
    } else if (mode_str == "CONSTANT") {
      fill_mode_ = Mode::FILL_CONSTANT;
    } else if (mode_str == "NEAREST") {
      fill_mode_ = Mode::FILL_NEAREST;
    } else {
      LOG(ERROR) << "Invalid mode " << mode_str
                 << ". Supported types: REFLECT, WRAP, CONSTANT, NEAREST";
    }
  }

  void Compute(OpKernelContext* ctx) override {
    DoImageProjectiveTransformOp<Device, T>(ctx, interpolation_, fill_mode_);
  }
};

#define REGISTER(TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2")  \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<TYPE>("dtype"), \
                          ImageProjectiveTransformV2<CPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

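// ImageProjectiveTransformV3 adds a scalar fill_value input. The shared
// implementation picks it up whenever a fourth input is present, so the kernel
// only needs to reuse the V2 attribute parsing.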
template <typename Device, typename T>
class ImageProjectiveTransformV3
    : public ImageProjectiveTransformV2<Device, T> {
 public:
  explicit ImageProjectiveTransformV3(OpKernelConstruction* ctx)
      : ImageProjectiveTransformV2<Device, T>(ctx) {}
};

#define REGISTER(TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3")  \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<TYPE>("dtype"), \
                          ImageProjectiveTransformV3<CPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;
typedef generator::Mode Mode;

namespace functor {

// NOTE(ringwalt): We get an undefined symbol error if we don't explicitly
// instantiate the operator() in GCC'd code.
#define DECLARE_PROJECT_FUNCTOR(TYPE)                                       \
  template <>                                                               \
  void FillProjectiveTransform<GPUDevice, TYPE>::operator()(                \
      const GPUDevice& device, OutputType* output, const InputType& images, \
      const TransformsType& transform, const Mode fill_mode,                \
      const TYPE fill_value) const;                                         \
  extern template struct FillProjectiveTransform<GPUDevice, TYPE>

TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int32(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int64(DECLARE_PROJECT_FUNCTOR);
TF_CALL_half(DECLARE_PROJECT_FUNCTOR);
TF_CALL_float(DECLARE_PROJECT_FUNCTOR);
TF_CALL_double(DECLARE_PROJECT_FUNCTOR);

}  // end namespace functor

namespace generator {

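// As above, declare the GPU MapCoordinate specializations for each fill mode
// so the definitions provided by the GPU build are found at link time.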
#define DECLARE_MAP_FUNCTOR(Mode)                                         \
  template <>                                                             \
  float MapCoordinate<GPUDevice, Mode>::operator()(const float out_coord, \
                                                   const DenseIndex len); \
  extern template struct MapCoordinate<GPUDevice, Mode>

DECLARE_MAP_FUNCTOR(Mode::FILL_REFLECT);
DECLARE_MAP_FUNCTOR(Mode::FILL_WRAP);
DECLARE_MAP_FUNCTOR(Mode::FILL_CONSTANT);
DECLARE_MAP_FUNCTOR(Mode::FILL_NEAREST);

}  // end namespace generator

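// The output_shape input is read on the host when computing the output tensor
// shape, so it is pinned to host memory for the GPU kernels.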
#define REGISTER(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<TYPE>("dtype") \
                              .HostMemory("output_shape"),   \
                          ImageProjectiveTransformV2<GPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

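// V3 additionally reads the scalar fill_value on the host, so it is kept in
// host memory as well.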
#define REGISTER(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<TYPE>("dtype") \
                              .HostMemory("output_shape")    \
                              .HostMemory("fill_value"),     \
                          ImageProjectiveTransformV3<GPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow