/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/image/image_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace functor {

// Explicit instantiation of the CPU functor.
typedef Eigen::ThreadPoolDevice CPUDevice;

template struct FillProjectiveTransform<CPUDevice, uint8>;
template struct FillProjectiveTransform<CPUDevice, int32>;
template struct FillProjectiveTransform<CPUDevice, int64_t>;
template struct FillProjectiveTransform<CPUDevice, Eigen::half>;
template struct FillProjectiveTransform<CPUDevice, float>;
template struct FillProjectiveTransform<CPUDevice, double>;

}  // end namespace functor

typedef Eigen::ThreadPoolDevice CPUDevice;

using functor::FillProjectiveTransform;
using generator::Interpolation;
using generator::Mode;

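// Shared implementation for the ImageProjectiveTransform{,V2,V3} kernels:
// validates the rank-4 NHWC `images` and the (num_images x 8 | 1 x 8)
// `transforms` inputs, resolves the output height/width (from the optional
// output_shape input, otherwise from the input shape), reads the optional
// fill_value input, and applies the FillProjectiveTransform functor on the
// given device.
//
// Rough Python-level usage (a sketch only; argument names may differ between
// TF versions):
//   out = tf.raw_ops.ImageProjectiveTransformV3(
//       images=images, transforms=transforms, output_shape=[h, w],
//       fill_value=0.0, interpolation="BILINEAR", fill_mode="CONSTANT")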
template <typename Device, typename T>
void DoImageProjectiveTransformOp(OpKernelContext* ctx,
                                  const Interpolation& interpolation,
                                  const Mode& fill_mode) {
  const Tensor& images_t = ctx->input(0);
  const Tensor& transform_t = ctx->input(1);
  OP_REQUIRES(ctx, images_t.shape().dims() == 4,
              errors::InvalidArgument("Input images must have rank 4"));
  OP_REQUIRES(ctx,
              (TensorShapeUtils::IsMatrix(transform_t.shape()) &&
               (transform_t.dim_size(0) == images_t.dim_size(0) ||
                transform_t.dim_size(0) == 1) &&
               transform_t.dim_size(1) == 8),
              errors::InvalidArgument(
                  "Input transform should be num_images x 8 or 1 x 8"));

  int32_t out_height, out_width;
  // This kernel is shared with the legacy "ImageProjectiveTransform" op, which
  // takes only 2 inputs and has no output_shape input.
  if (ctx->num_inputs() >= 3) {
    const Tensor& shape_t = ctx->input(2);
    OP_REQUIRES(ctx, shape_t.dims() == 1,
                errors::InvalidArgument("output shape must be 1-dimensional",
                                        shape_t.shape().DebugString()));
    OP_REQUIRES(ctx, shape_t.NumElements() == 2,
                errors::InvalidArgument("output shape must have two elements",
                                        shape_t.shape().DebugString()));
    auto shape_vec = shape_t.vec<int32>();
    out_height = shape_vec(0);
    out_width = shape_vec(1);
    OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
                errors::InvalidArgument("output dimensions must be positive"));
  } else {
    // Shape is N (batch size), H (height), W (width), C (channels).
    out_height = images_t.shape().dim_size(1);
    out_width = images_t.shape().dim_size(2);
  }

  T fill_value(0);
  // The "ImageProjectiveTransformV2" op takes only 3 inputs; the fill_value
  // input exists only on "ImageProjectiveTransformV3".
  if (ctx->num_inputs() >= 4) {
    const Tensor& fill_value_t = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(fill_value_t.shape()),
                errors::InvalidArgument("fill_value must be a scalar",
                                        fill_value_t.shape().DebugString()));
    fill_value = static_cast<T>(*(fill_value_t.scalar<float>().data()));
  }

  Tensor* output_t;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(0,
                                TensorShape({images_t.dim_size(0), out_height,
                                             out_width, images_t.dim_size(3)}),
                                &output_t));
  auto output = output_t->tensor<T, 4>();
  auto images = images_t.tensor<T, 4>();
  auto transform = transform_t.matrix<float>();

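  // Each row of `transform` holds the 8 projective coefficients
  // [a0, a1, a2, b0, b1, b2, c0, c1]. As documented for the
  // ImageProjectiveTransform* ops, output pixel (x, y) samples the input at
  // ((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k) with k = c0*x + c1*y + 1,
  // using the requested interpolation; out-of-range reads are resolved by
  // `fill_mode` (and `fill_value` for the CONSTANT mode).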
  (FillProjectiveTransform<Device, T>(interpolation))(
      ctx->eigen_device<Device>(), &output, images, transform, fill_mode,
      fill_value);
}

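// OpKernel for "ImageProjectiveTransformV2": parses the `interpolation` and
// `fill_mode` string attributes once at construction time and delegates each
// Compute() call to DoImageProjectiveTransformOp.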
template <typename Device, typename T>
class ImageProjectiveTransformV2 : public OpKernel {
 private:
  Interpolation interpolation_;
  Mode fill_mode_;

 public:
  explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    string interpolation_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
    if (interpolation_str == "NEAREST") {
      interpolation_ = Interpolation::NEAREST;
    } else if (interpolation_str == "BILINEAR") {
      interpolation_ = Interpolation::BILINEAR;
    } else {
      LOG(ERROR) << "Invalid interpolation " << interpolation_str
                 << ". Supported types: NEAREST, BILINEAR";
    }
    string mode_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_mode", &mode_str));
    if (mode_str == "REFLECT") {
      fill_mode_ = Mode::FILL_REFLECT;
    } else if (mode_str == "WRAP") {
      fill_mode_ = Mode::FILL_WRAP;
    } else if (mode_str == "CONSTANT") {
      fill_mode_ = Mode::FILL_CONSTANT;
    } else if (mode_str == "NEAREST") {
      fill_mode_ = Mode::FILL_NEAREST;
    } else {
      LOG(ERROR) << "Invalid mode " << mode_str
                 << ". Supported types: REFLECT, WRAP, CONSTANT, NEAREST";
    }
  }

  void Compute(OpKernelContext* ctx) override {
    DoImageProjectiveTransformOp<Device, T>(ctx, interpolation_, fill_mode_);
  }
};

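// Register the CPU kernels of ImageProjectiveTransformV2 for every supported
// image dtype.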
#define REGISTER(TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2")  \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<TYPE>("dtype"), \
                          ImageProjectiveTransformV2<CPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

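// ImageProjectiveTransformV3 adds a fill_value input but otherwise behaves
// exactly like V2: the attribute parsing is inherited, and the shared
// DoImageProjectiveTransformOp picks up the extra input via ctx->num_inputs().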
template <typename Device, typename T>
class ImageProjectiveTransformV3
    : public ImageProjectiveTransformV2<Device, T> {
 public:
  explicit ImageProjectiveTransformV3(OpKernelConstruction* ctx)
      : ImageProjectiveTransformV2<Device, T>(ctx) {}
};

#define REGISTER(TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3")  \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<TYPE>("dtype"), \
                          ImageProjectiveTransformV3<CPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;
typedef generator::Mode Mode;

namespace functor {

// NOTE(ringwalt): We get an undefined symbol error if we don't explicitly
// instantiate the operator() in GCC'd code.
#define DECLARE_PROJECT_FUNCTOR(TYPE)                                        \
  template <>                                                                \
  void FillProjectiveTransform<GPUDevice, TYPE>::operator()(                 \
      const GPUDevice& device, OutputType* output, const InputType& images,  \
      const TransformsType& transform, const Mode fill_mode,                 \
      const TYPE fill_value) const;                                          \
  extern template struct FillProjectiveTransform<GPUDevice, TYPE>

TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int32(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int64(DECLARE_PROJECT_FUNCTOR);
TF_CALL_half(DECLARE_PROJECT_FUNCTOR);
TF_CALL_float(DECLARE_PROJECT_FUNCTOR);
TF_CALL_double(DECLARE_PROJECT_FUNCTOR);

}  // end namespace functor

namespace generator {

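// Forward declarations of the MapCoordinate<GPUDevice, Mode> specializations
// (one per fill mode); their definitions are compiled in the CUDA translation
// unit.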
#define DECLARE_MAP_FUNCTOR(Mode)                                          \
  template <>                                                              \
  float MapCoordinate<GPUDevice, Mode>::operator()(const float out_coord, \
                                                   const DenseIndex len); \
  extern template struct MapCoordinate<GPUDevice, Mode>

DECLARE_MAP_FUNCTOR(Mode::FILL_REFLECT);
DECLARE_MAP_FUNCTOR(Mode::FILL_WRAP);
DECLARE_MAP_FUNCTOR(Mode::FILL_CONSTANT);
DECLARE_MAP_FUNCTOR(Mode::FILL_NEAREST);

}  // end namespace generator

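// GPU kernels for ImageProjectiveTransformV2. output_shape is pinned to host
// memory because it is read on the host to size the output allocation.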
#define REGISTER(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<TYPE>("dtype") \
                              .HostMemory("output_shape"),   \
                          ImageProjectiveTransformV2<GPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

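// GPU kernels for ImageProjectiveTransformV3. fill_value is also pinned to
// host memory since the kernel reads it as a host-side scalar.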
#define REGISTER(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<TYPE>("dtype") \
                              .HostMemory("output_shape")    \
                              .HostMemory("fill_value"),     \
                          ImageProjectiveTransformV3<GPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow