1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // See docs in ../ops/image_ops.cc.
#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"

using tensorflow::random::SimplePhilox;
31
32 namespace tensorflow {
33 namespace {
34
35 // A simple Rectangle class that supplies intersection.
// A simple axis-aligned integer rectangle described by its min corner
// (min_x_, min_y_) and max corner (max_x_, max_y_). Supplies area and
// intersection. In this file every Rectangle is built from non-negative pixel
// coordinates bounded by the image size.
class Rectangle {
 public:
  Rectangle() { Set(0, 0, 0, 0); }
  Rectangle(int xmin, int ymin, int xmax, int ymax) {
    Set(xmin, ymin, xmax, ymax);
  }

  void Set(int xmin, int ymin, int xmax, int ymax) {
    min_x_ = xmin;
    min_y_ = ymin;
    max_x_ = xmax;
    max_y_ = ymax;
  }

  // True when the min corner lies beyond the max corner on either axis.
  bool IsEmpty() const { return min_x_ > max_x_ || min_y_ > max_y_; }

  // Returns (max_x_ - min_x_) * (max_y_ - min_y_) as a float.
  // The product is formed in 64-bit integer arithmetic: each side can approach
  // INT32_MAX for very large images, and a 32-bit `int` product would overflow
  // (undefined behavior). Whenever the old `int` product did not overflow, the
  // result is bit-identical to it.
  float Area() const {
    const int64_t width = static_cast<int64_t>(max_x_) - min_x_;
    const int64_t height = static_cast<int64_t>(max_y_) - min_y_;
    return static_cast<float>(width * height);
  }

  // Returns the overlap of this rectangle with `r`, or a default (all-zero,
  // zero-area) rectangle when the two do not overlap. Rectangles that merely
  // share an edge produce a zero-area, non-default result.
  Rectangle Intersect(const Rectangle& r) const {
    const int pmin_x = std::max(min_x_, r.min_x_);
    const int pmin_y = std::max(min_y_, r.min_y_);
    const int pmax_x = std::min(max_x_, r.max_x_);
    const int pmax_y = std::min(max_y_, r.max_y_);

    if (pmin_x > pmax_x || pmin_y > pmax_y) {
      return Rectangle();
    } else {
      return Rectangle(pmin_x, pmin_y, pmax_x, pmax_y);
    }
  }

  int min_x_;
  int min_y_;
  int max_x_;
  int max_y_;
};
73
74 // Determine if the supplied cropping box covers a sufficient fraction of the
75 // the supplied bounding boxes.
SatisfiesOverlapConstraints(const Rectangle & crop,float minimum_object_covered,const std::vector<Rectangle> & bounding_boxes)76 bool SatisfiesOverlapConstraints(const Rectangle& crop,
77 float minimum_object_covered,
78 const std::vector<Rectangle>& bounding_boxes) {
79 // Reject any bounding box which contains no pixels.
80 const float kMinArea = 1.0;
81 if (crop.Area() < kMinArea) {
82 return false;
83 }
84
85 // Loop through all objects and determine if the proposed cropping box covers
86 // a sufficient fraction of one of the supplied bounding boxes.
87 bool is_object_covered = false;
88 for (const auto& bbox : bounding_boxes) {
89 const float object_area = bbox.Area();
90 if (object_area < kMinArea) {
91 continue;
92 }
93
94 const float object_covered = crop.Intersect(bbox).Area() / object_area;
95
96 if (object_covered >= minimum_object_covered) {
97 is_object_covered = true;
98 break;
99 }
100 }
101 return is_object_covered;
102 }
103
104 // Generate a random crop within the rectangle
105 // (0, 0, original_width, original_height).
106 // The minimum area of the crop will be between
107 // min_relative_crop_area * orig_width * orig_height
108 // and
109 // max_relative_crop_area * orig_width * orig_height
110 // such that its width = round(aspect_ratio * height).
111 // The diameter of the generated rectangle will be uniformly distributed between
112 // its minimum and maximum size. The center of the rectangle will be distributed
113 // uniformly within the source rectangle. The function returns false if the
114 // rectangle could not be generated with the given constraints.
GenerateRandomCrop(int original_width,int original_height,float min_relative_crop_area,float max_relative_crop_area,float aspect_ratio,SimplePhilox * random,Rectangle * crop_rect)115 bool GenerateRandomCrop(int original_width, int original_height,
116 float min_relative_crop_area,
117 float max_relative_crop_area, float aspect_ratio,
118 SimplePhilox* random, Rectangle* crop_rect) {
119 if (max_relative_crop_area <= 0.0 || aspect_ratio <= 0.0 ||
120 original_width <= 0 || original_height <= 0 ||
121 min_relative_crop_area > max_relative_crop_area) {
122 return false;
123 }
124
125 const float min_area =
126 min_relative_crop_area * original_width * original_height;
127 const float max_area =
128 max_relative_crop_area * original_width * original_height;
129
130 int height = static_cast<int>(lrintf(std::sqrt(min_area / aspect_ratio)));
131 int max_height = static_cast<int>(lrintf(std::sqrt(max_area / aspect_ratio)));
132
133 // TODO(b/140767341): Rewrite the generation logic to be more tolerant
134 // of floating point behavior.
135 if (lrintf(max_height * aspect_ratio) > original_width) {
136 // We must find the smallest max_height satisfying
137 // round(max_height * aspect_ratio) <= original_width:
138 const float kEps = 0.0000001;
139 max_height = static_cast<int>((original_width + 0.5 - kEps) / aspect_ratio);
140 // If due some precision issues, we still cannot guarantee
141 // round(max_height * aspect_ratio) <= original_width, subtract 1 from
142 // max height.
143 if (lrintf(max_height * aspect_ratio) > original_width) {
144 max_height -= 1;
145 }
146 }
147
148 if (max_height > original_height) {
149 max_height = original_height;
150 }
151
152 if (height >= max_height) {
153 height = max_height;
154 }
155
156 if (height < max_height) {
157 // We need to generate a random number in the closed range
158 // [0, max_height - height].
159 height += random->Uniform(max_height - height + 1);
160 }
161 int width = static_cast<int>(lrintf(height * aspect_ratio));
162 DCHECK_LE(width, original_width);
163
164 // Let us not fail if rounding error causes the area to be
165 // outside the constraints.
166 // Try first with a slightly bigger rectangle first.
167 float area = static_cast<float>(width * height);
168 if (area < min_area) {
169 height += 1;
170 width = static_cast<int>(lrintf(height * aspect_ratio));
171 area = width * height;
172 }
173
174 // Let us not fail if rounding error causes the area to be
175 // outside the constraints.
176 // Try first with a slightly smaller rectangle first.
177 if (area > max_area) {
178 height -= 1;
179 width = static_cast<int>(lrintf(height * aspect_ratio));
180 area = width * height;
181 }
182
183 // Now, we explored all options to rectify small rounding errors.
184 // It seems the constraints can't be satisfied: return false.
185 if (area < min_area || area > max_area || width > original_width ||
186 height > original_height || width <= 0 || height <= 0) {
187 return false;
188 }
189
190 int y = 0;
191 if (height < original_height) {
192 y = random->Uniform(original_height - height);
193 }
194 int x = 0;
195 if (width < original_width) {
196 x = random->Uniform(original_width - width);
197 }
198
199 crop_rect->min_x_ = x;
200 crop_rect->min_y_ = y;
201 crop_rect->max_x_ = x + width;
202 crop_rect->max_y_ = y + height;
203 return true;
204 }
205 } // namespace
206
207 template <typename T>
208 class SampleDistortedBoundingBoxBaseOp : public OpKernel {
209 public:
SampleDistortedBoundingBoxBaseOp(OpKernelConstruction * context)210 explicit SampleDistortedBoundingBoxBaseOp(OpKernelConstruction* context)
211 : OpKernel(context) {
212 if (context->num_inputs() == 2) {
213 OP_REQUIRES_OK(context, context->GetAttr("min_object_covered",
214 &min_object_covered_));
215 OP_REQUIRES(
216 context, min_object_covered_ >= 0,
217 errors::InvalidArgument("Min object covered must be non-negative: ",
218 min_object_covered_));
219 }
220
221 OP_REQUIRES_OK(context, context->GetAttr("use_image_if_no_bounding_boxes",
222 &use_image_if_no_bounding_boxes_));
223
224 OP_REQUIRES_OK(
225 context, context->GetAttr("aspect_ratio_range", &aspect_ratio_range_));
226 OP_REQUIRES(context, aspect_ratio_range_.size() == 2,
227 errors::InvalidArgument(
228 "Aspect ratio range field must specify 2 dimensions"));
229
230 OP_REQUIRES(
231 context, aspect_ratio_range_[0] > 0 && aspect_ratio_range_[1] > 0,
232 errors::InvalidArgument("Aspect ratio range must be non-negative: [",
233 aspect_ratio_range_[0], ", ",
234 aspect_ratio_range_[1], "]"));
235
236 OP_REQUIRES_OK(context, context->GetAttr("area_range", &area_range_));
237 OP_REQUIRES(
238 context, area_range_.size() == 2,
239 errors::InvalidArgument("Area range field must specify 2 dimensions"));
240
241 OP_REQUIRES(
242 context, area_range_[0] > 0 && area_range_[1] > 0,
243 errors::InvalidArgument("Area range must be non-negative: [",
244 area_range_[0], ", ", area_range_[1], "]"));
245
246 OP_REQUIRES(context, area_range_[0] <= 1 && area_range_[1] <= 1,
247 errors::InvalidArgument(
248 "Area range must be less then or equal to 1.0: [",
249 area_range_[0], ", ", area_range_[1], "]"));
250
251 OP_REQUIRES_OK(context, context->GetAttr("max_attempts", &max_attempts_));
252 OP_REQUIRES(context, max_attempts_ > 0,
253 errors::InvalidArgument("Max attempts must be non-negative: ",
254 max_attempts_));
255 }
256
DoCompute(OpKernelContext * context,const random::PhiloxRandom & rng)257 void DoCompute(OpKernelContext* context, const random::PhiloxRandom& rng) {
258 const Tensor& image_size = context->input(0);
259
260 OP_REQUIRES(context, image_size.dims() == 1,
261 errors::InvalidArgument("image_size must be 1-dimensional",
262 image_size.shape().DebugString()));
263 OP_REQUIRES(context, image_size.dim_size(0) == 3,
264 errors::InvalidArgument("image_size must contain 3 elements",
265 image_size.shape().DebugString()));
266
267 // Note image_size_data(2) is the depth and unused.
268 const uint64 height_raw = internal::SubtleMustCopy(image_size.flat<T>()(0));
269 const uint64 width_raw = internal::SubtleMustCopy(image_size.flat<T>()(1));
270 OP_REQUIRES(context,
271 FastBoundsCheck(height_raw, std::numeric_limits<int32>::max()),
272 errors::InvalidArgument("image height cannot be >= int32 max"));
273 OP_REQUIRES(context,
274 FastBoundsCheck(width_raw, std::numeric_limits<int32>::max()),
275 errors::InvalidArgument("image width cannot be >= int32 max"));
276 const int32_t height = static_cast<int32>(height_raw);
277 const int32_t width = static_cast<int32>(width_raw);
278
279 // Ensure that the supplied bounding boxes are sane and convert them to
280 // Rectangles.
281 const Tensor& input_boxes = context->input(1);
282 OP_REQUIRES(context, input_boxes.dims() == 3,
283 errors::InvalidArgument("input boxes must be 3-dimensional "
284 "[batch, num_boxes, coords]: ",
285 input_boxes.shape().DebugString()));
286 OP_REQUIRES(context, input_boxes.dim_size(input_boxes.dims() - 1) == 4,
287 errors::InvalidArgument(
288 "bounding boxes must have shape [4] or [*, 4], got ",
289 input_boxes.shape().DebugString()));
290
291 float min_object_covered_val = 0.0;
292 // `SampleDistortedBoundingBox` op accepts 2 inputs and has
293 // `min_object_covered` as an attribute (handled in the constructor).
294 // `SampleDistortedBoundingBoxV2` and `StatelessSampleDistortedBoundingBox`
295 // ops accept 3+ inputs, including `min_object_covered`.
296 if (context->num_inputs() >= 3) {
297 const Tensor& min_object_covered = context->input(2);
298
299 OP_REQUIRES(
300 context, TensorShapeUtils::IsScalar(min_object_covered.shape()),
301 errors::InvalidArgument("min_object_covered must be 0-D, got shape ",
302 min_object_covered.shape().DebugString()));
303
304 min_object_covered_val = min_object_covered.scalar<float>()();
305
306 OP_REQUIRES(
307 context, min_object_covered_val >= 0,
308 errors::InvalidArgument("Min object covered must be non-negative: ",
309 min_object_covered_val));
310 } else {
311 min_object_covered_val = min_object_covered_;
312 }
313
314 std::vector<Rectangle> bounding_boxes;
315 if (input_boxes.NumElements() > 0) {
316 TTypes<float>::ConstMatrix boxes = input_boxes.flat_inner_dims<float>();
317 for (int b = 0; b < boxes.dimension(0); ++b) {
318 for (int i = 0; i < 4; ++i) {
319 OP_REQUIRES(
320 context, boxes(b, i) >= 0.0 && boxes(b, i) <= 1.0,
321 errors::InvalidArgument("All bounding box coordinates must "
322 "be in [0.0, 1.0]: ",
323 boxes(b, i)));
324 }
325
326 const int32_t x_min = static_cast<int32>(boxes(b, 1) * width);
327 const int32_t y_min = static_cast<int32>(boxes(b, 0) * height);
328 const int32_t x_max = static_cast<int32>(boxes(b, 3) * width);
329 const int32_t y_max = static_cast<int32>(boxes(b, 2) * height);
330
331 bounding_boxes.push_back(Rectangle(x_min, y_min, x_max, y_max));
332 }
333 }
334
335 // Insert the entire image if no bounding boxes are supplied.
336 const Rectangle image_rect(0, 0, width, height);
337 if (bounding_boxes.empty()) {
338 OP_REQUIRES(context, use_image_if_no_bounding_boxes_,
339 errors::InvalidArgument(
340 "No bounding boxes provided as input. One must "
341 "enable use_image_if_no_bounding_boxes if you wish "
342 "to not provide any bounding boxes."));
343 bounding_boxes.push_back(image_rect);
344 }
345
346 const float min_sample_area = area_range_[0];
347 const float max_sample_area = area_range_[1];
348 const float min_sample_aspect_ratio = aspect_ratio_range_[0];
349 const float max_sample_aspect_ratio = aspect_ratio_range_[1];
350
351 auto local_rng = rng;
352 random::SimplePhilox random(&local_rng);
353
354 Rectangle crop_rect;
355 bool sample_generated = false;
356 for (int i = 0; i < max_attempts_; ++i) {
357 const float sample_aspect_ratio =
358 random.RandFloat() *
359 (max_sample_aspect_ratio - min_sample_aspect_ratio) +
360 min_sample_aspect_ratio;
361
362 if (GenerateRandomCrop(width, height, min_sample_area, max_sample_area,
363 sample_aspect_ratio, &random, &crop_rect)) {
364 if (SatisfiesOverlapConstraints(crop_rect, min_object_covered_val,
365 bounding_boxes)) {
366 sample_generated = true;
367 break;
368 }
369 }
370 }
371
372 if (!sample_generated) {
373 crop_rect = image_rect;
374 }
375
376 // Determine the cropping parameters from the bounding box.
377 const int target_width = crop_rect.max_x_ - crop_rect.min_x_;
378 const int target_height = crop_rect.max_y_ - crop_rect.min_y_;
379
380 const int offset_width = crop_rect.min_x_;
381 const int offset_height = crop_rect.min_y_;
382
383 // Ensure that the bounding box fits in the image dimensions.
384 OP_REQUIRES(context, width >= target_width + offset_width,
385 errors::FailedPrecondition(
386 "width must be > target_width + offset_width: ", width,
387 "vs ", target_width, " + ", offset_width));
388 OP_REQUIRES(context, height >= target_height + offset_height,
389 errors::FailedPrecondition(
390 "height must be >= target_height: height = ", height, "vs ",
391 target_height, " + ", offset_height));
392
393 // Create two vectors, each 3 elements, to provide as arguments to Slice.
394 // See Slice() operation for details.
395 Tensor* begin = nullptr;
396 OP_REQUIRES_OK(context,
397 context->allocate_output(0, TensorShape({3}), &begin));
398 Tensor* size = nullptr;
399 OP_REQUIRES_OK(context,
400 context->allocate_output(1, TensorShape({3}), &size));
401 Tensor* bboxes = nullptr;
402 OP_REQUIRES_OK(
403 context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes));
404
405 typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>());
406 typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>());
407 TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
408
409 begin_data(0) = T(offset_height);
410 size_data(0) = T(target_height);
411
412 begin_data(1) = T(offset_width);
413 size_data(1) = T(target_width);
414
415 bboxes_data(0, 0, 0) =
416 static_cast<float>(crop_rect.min_y_) / static_cast<float>(height);
417 bboxes_data(0, 0, 1) =
418 static_cast<float>(crop_rect.min_x_) / static_cast<float>(width);
419 bboxes_data(0, 0, 2) =
420 static_cast<float>(crop_rect.max_y_) / static_cast<float>(height);
421 bboxes_data(0, 0, 3) =
422 static_cast<float>(crop_rect.max_x_) / static_cast<float>(width);
423
424 // Retain all of the channels.
425 begin_data(2) = T(0);
426 size_data(2) = T(-1);
427 }
428
429 protected:
430 int32 max_attempts_;
431 std::vector<float> area_range_;
432 std::vector<float> aspect_ratio_range_;
433 float min_object_covered_;
434 bool use_image_if_no_bounding_boxes_;
435 };
436
437 template <typename T>
438 class StatefulSampleDistortedBoundingBoxOp
439 : public SampleDistortedBoundingBoxBaseOp<T> {
440 public:
StatefulSampleDistortedBoundingBoxOp(OpKernelConstruction * context)441 explicit StatefulSampleDistortedBoundingBoxOp(OpKernelConstruction* context)
442 : SampleDistortedBoundingBoxBaseOp<T>(context) {
443 OP_REQUIRES_OK(context, generator_.Init(context));
444 }
445
Compute(OpKernelContext * context)446 void Compute(OpKernelContext* context) override {
447 // Need to reserve samples since `generator_` is shared.
448 this->DoCompute(context,
449 generator_.ReserveSamples32(4 * this->max_attempts_));
450 }
451
452 private:
453 GuardedPhiloxRandom generator_;
454 };
455
456 template <typename T>
457 class StatelessSampleDistortedBoundingBoxOp
458 : public SampleDistortedBoundingBoxBaseOp<T> {
459 public:
StatelessSampleDistortedBoundingBoxOp(OpKernelConstruction * context)460 explicit StatelessSampleDistortedBoundingBoxOp(OpKernelConstruction* context)
461 : SampleDistortedBoundingBoxBaseOp<T>(context) {}
462
Compute(OpKernelContext * context)463 void Compute(OpKernelContext* context) override {
464 const Tensor& seed_t = context->input(3);
465 OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
466 errors::InvalidArgument("seed must have shape [2], not ",
467 seed_t.shape().DebugString()));
468
469 // Create and initialize stateless random number generator (rng).
470 // There is no need to `Skip` (or reserve) samples since the scope of this
471 // rng is local.
472 random::PhiloxRandom::Key key;
473 random::PhiloxRandom::ResultType counter;
474 OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
475
476 this->DoCompute(context, random::PhiloxRandom(counter, key));
477 }
478 };
479
480 #define REGISTER_KERNELS(type) \
481 REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBox") \
482 .Device(DEVICE_CPU) \
483 .TypeConstraint<type>("T"), \
484 StatefulSampleDistortedBoundingBoxOp<type>) \
485 REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBoxV2") \
486 .Device(DEVICE_CPU) \
487 .TypeConstraint<type>("T"), \
488 StatefulSampleDistortedBoundingBoxOp<type>) \
489 REGISTER_KERNEL_BUILDER(Name("StatelessSampleDistortedBoundingBox") \
490 .Device(DEVICE_CPU) \
491 .TypeConstraint<type>("T"), \
492 StatelessSampleDistortedBoundingBoxOp<type>)
493
494 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
495 #undef REGISTER_KERNELS
496
497 } // namespace tensorflow
498