• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // See docs in ../ops/image_ops.cc.
16 #include <math.h>
17 
18 #include <cmath>
19 
20 #include "tensorflow/core/framework/bounds_check.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/kernels/stateless_random_ops.h"
26 #include "tensorflow/core/lib/random/philox_random.h"
27 #include "tensorflow/core/lib/random/simple_philox.h"
28 #include "tensorflow/core/util/guarded_philox_random.h"
29 
30 using tensorflow::random::SimplePhilox;
31 
32 namespace tensorflow {
33 namespace {
34 
// A minimal axis-aligned rectangle with integer corners that supports area
// computation and intersection. The corner coordinates are public so callers
// can read and write them directly.
class Rectangle {
 public:
  // Builds the degenerate rectangle (0, 0, 0, 0).
  Rectangle() { Set(0, 0, 0, 0); }

  // Builds a rectangle from its minimum and maximum corner.
  Rectangle(int xmin, int ymin, int xmax, int ymax) {
    Set(xmin, ymin, xmax, ymax);
  }

  // Overwrites all four coordinates at once.
  void Set(int xmin, int ymin, int xmax, int ymax) {
    min_x_ = xmin;
    min_y_ = ymin;
    max_x_ = xmax;
    max_y_ = ymax;
  }

  // True when the minimum exceeds the maximum along either axis.
  bool IsEmpty() const { return min_x_ > max_x_ || min_y_ > max_y_; }

  // Returns (max_x - min_x) * (max_y - min_y) as a float.
  //
  // The extents and their product are evaluated in 64-bit arithmetic: two
  // extents that each fit in an int can still have a product that does not,
  // and signed overflow is undefined behavior.
  float Area() const {
    const int64_t extent_x = static_cast<int64_t>(max_x_) - min_x_;
    const int64_t extent_y = static_cast<int64_t>(max_y_) - min_y_;
    return static_cast<float>(extent_x * extent_y);
  }

  // Returns the overlap between this rectangle and `r`. When the two do not
  // overlap, the default-constructed rectangle (0, 0, 0, 0) is returned.
  Rectangle Intersect(const Rectangle& r) const {
    const int pmin_x = std::max(min_x_, r.min_x_);
    const int pmin_y = std::max(min_y_, r.min_y_);
    const int pmax_x = std::min(max_x_, r.max_x_);
    const int pmax_y = std::min(max_y_, r.max_y_);

    if (pmin_x > pmax_x || pmin_y > pmax_y) {
      return Rectangle();
    }
    return Rectangle(pmin_x, pmin_y, pmax_x, pmax_y);
  }

  int min_x_;
  int min_y_;
  int max_x_;
  int max_y_;
};
73 
74 // Determine if the supplied cropping box covers a sufficient fraction of the
75 // the supplied bounding boxes.
SatisfiesOverlapConstraints(const Rectangle & crop,float minimum_object_covered,const std::vector<Rectangle> & bounding_boxes)76 bool SatisfiesOverlapConstraints(const Rectangle& crop,
77                                  float minimum_object_covered,
78                                  const std::vector<Rectangle>& bounding_boxes) {
79   // Reject any bounding box which contains no pixels.
80   const float kMinArea = 1.0;
81   if (crop.Area() < kMinArea) {
82     return false;
83   }
84 
85   // Loop through all objects and determine if the proposed cropping box covers
86   // a sufficient fraction of one of the supplied bounding boxes.
87   bool is_object_covered = false;
88   for (const auto& bbox : bounding_boxes) {
89     const float object_area = bbox.Area();
90     if (object_area < kMinArea) {
91       continue;
92     }
93 
94     const float object_covered = crop.Intersect(bbox).Area() / object_area;
95 
96     if (object_covered >= minimum_object_covered) {
97       is_object_covered = true;
98       break;
99     }
100   }
101   return is_object_covered;
102 }
103 
104 // Generate a random crop within the rectangle
105 // (0, 0, original_width, original_height).
106 // The minimum area of the crop will be between
107 //   min_relative_crop_area * orig_width * orig_height
108 // and
109 //   max_relative_crop_area * orig_width * orig_height
110 // such that its width = round(aspect_ratio * height).
111 // The diameter of the generated rectangle will be uniformly distributed between
112 // its minimum and maximum size. The center of the rectangle will be distributed
113 // uniformly within the source rectangle. The function returns false if the
114 // rectangle could not be generated with the given constraints.
GenerateRandomCrop(int original_width,int original_height,float min_relative_crop_area,float max_relative_crop_area,float aspect_ratio,SimplePhilox * random,Rectangle * crop_rect)115 bool GenerateRandomCrop(int original_width, int original_height,
116                         float min_relative_crop_area,
117                         float max_relative_crop_area, float aspect_ratio,
118                         SimplePhilox* random, Rectangle* crop_rect) {
119   if (max_relative_crop_area <= 0.0 || aspect_ratio <= 0.0 ||
120       original_width <= 0 || original_height <= 0 ||
121       min_relative_crop_area > max_relative_crop_area) {
122     return false;
123   }
124 
125   const float min_area =
126       min_relative_crop_area * original_width * original_height;
127   const float max_area =
128       max_relative_crop_area * original_width * original_height;
129 
130   int height = static_cast<int>(lrintf(std::sqrt(min_area / aspect_ratio)));
131   int max_height = static_cast<int>(lrintf(std::sqrt(max_area / aspect_ratio)));
132 
133   // TODO(b/140767341): Rewrite the generation logic to be more tolerant
134   // of floating point behavior.
135   if (lrintf(max_height * aspect_ratio) > original_width) {
136     // We must find the smallest max_height satisfying
137     // round(max_height * aspect_ratio) <= original_width:
138     const float kEps = 0.0000001;
139     max_height = static_cast<int>((original_width + 0.5 - kEps) / aspect_ratio);
140     // If due some precision issues, we still cannot guarantee
141     // round(max_height * aspect_ratio) <= original_width, subtract 1 from
142     // max height.
143     if (lrintf(max_height * aspect_ratio) > original_width) {
144       max_height -= 1;
145     }
146   }
147 
148   if (max_height > original_height) {
149     max_height = original_height;
150   }
151 
152   if (height >= max_height) {
153     height = max_height;
154   }
155 
156   if (height < max_height) {
157     // We need to generate a random number in the closed range
158     // [0, max_height - height].
159     height += random->Uniform(max_height - height + 1);
160   }
161   int width = static_cast<int>(lrintf(height * aspect_ratio));
162   DCHECK_LE(width, original_width);
163 
164   // Let us not fail if rounding error causes the area to be
165   // outside the constraints.
166   // Try first with a slightly bigger rectangle first.
167   float area = static_cast<float>(width * height);
168   if (area < min_area) {
169     height += 1;
170     width = static_cast<int>(lrintf(height * aspect_ratio));
171     area = width * height;
172   }
173 
174   // Let us not fail if rounding error causes the area to be
175   // outside the constraints.
176   // Try first with a slightly smaller rectangle first.
177   if (area > max_area) {
178     height -= 1;
179     width = static_cast<int>(lrintf(height * aspect_ratio));
180     area = width * height;
181   }
182 
183   // Now, we explored all options to rectify small rounding errors.
184   // It seems the constraints can't be satisfied: return false.
185   if (area < min_area || area > max_area || width > original_width ||
186       height > original_height || width <= 0 || height <= 0) {
187     return false;
188   }
189 
190   int y = 0;
191   if (height < original_height) {
192     y = random->Uniform(original_height - height);
193   }
194   int x = 0;
195   if (width < original_width) {
196     x = random->Uniform(original_width - width);
197   }
198 
199   crop_rect->min_x_ = x;
200   crop_rect->min_y_ = y;
201   crop_rect->max_x_ = x + width;
202   crop_rect->max_y_ = y + height;
203   return true;
204 }
205 }  // namespace
206 
207 template <typename T>
208 class SampleDistortedBoundingBoxBaseOp : public OpKernel {
209  public:
SampleDistortedBoundingBoxBaseOp(OpKernelConstruction * context)210   explicit SampleDistortedBoundingBoxBaseOp(OpKernelConstruction* context)
211       : OpKernel(context) {
212     if (context->num_inputs() == 2) {
213       OP_REQUIRES_OK(context, context->GetAttr("min_object_covered",
214                                                &min_object_covered_));
215       OP_REQUIRES(
216           context, min_object_covered_ >= 0,
217           errors::InvalidArgument("Min object covered must be non-negative: ",
218                                   min_object_covered_));
219     }
220 
221     OP_REQUIRES_OK(context, context->GetAttr("use_image_if_no_bounding_boxes",
222                                              &use_image_if_no_bounding_boxes_));
223 
224     OP_REQUIRES_OK(
225         context, context->GetAttr("aspect_ratio_range", &aspect_ratio_range_));
226     OP_REQUIRES(context, aspect_ratio_range_.size() == 2,
227                 errors::InvalidArgument(
228                     "Aspect ratio range field must specify 2 dimensions"));
229 
230     OP_REQUIRES(
231         context, aspect_ratio_range_[0] > 0 && aspect_ratio_range_[1] > 0,
232         errors::InvalidArgument("Aspect ratio range must be non-negative: [",
233                                 aspect_ratio_range_[0], ", ",
234                                 aspect_ratio_range_[1], "]"));
235 
236     OP_REQUIRES_OK(context, context->GetAttr("area_range", &area_range_));
237     OP_REQUIRES(
238         context, area_range_.size() == 2,
239         errors::InvalidArgument("Area range field must specify 2 dimensions"));
240 
241     OP_REQUIRES(
242         context, area_range_[0] > 0 && area_range_[1] > 0,
243         errors::InvalidArgument("Area range must be non-negative: [",
244                                 area_range_[0], ", ", area_range_[1], "]"));
245 
246     OP_REQUIRES(context, area_range_[0] <= 1 && area_range_[1] <= 1,
247                 errors::InvalidArgument(
248                     "Area range must be less then or equal to 1.0: [",
249                     area_range_[0], ", ", area_range_[1], "]"));
250 
251     OP_REQUIRES_OK(context, context->GetAttr("max_attempts", &max_attempts_));
252     OP_REQUIRES(context, max_attempts_ > 0,
253                 errors::InvalidArgument("Max attempts must be non-negative: ",
254                                         max_attempts_));
255   }
256 
DoCompute(OpKernelContext * context,const random::PhiloxRandom & rng)257   void DoCompute(OpKernelContext* context, const random::PhiloxRandom& rng) {
258     const Tensor& image_size = context->input(0);
259 
260     OP_REQUIRES(context, image_size.dims() == 1,
261                 errors::InvalidArgument("image_size must be 1-dimensional",
262                                         image_size.shape().DebugString()));
263     OP_REQUIRES(context, image_size.dim_size(0) == 3,
264                 errors::InvalidArgument("image_size must contain 3 elements",
265                                         image_size.shape().DebugString()));
266 
267     // Note image_size_data(2) is the depth and unused.
268     const uint64 height_raw = internal::SubtleMustCopy(image_size.flat<T>()(0));
269     const uint64 width_raw = internal::SubtleMustCopy(image_size.flat<T>()(1));
270     OP_REQUIRES(context,
271                 FastBoundsCheck(height_raw, std::numeric_limits<int32>::max()),
272                 errors::InvalidArgument("image height cannot be >= int32 max"));
273     OP_REQUIRES(context,
274                 FastBoundsCheck(width_raw, std::numeric_limits<int32>::max()),
275                 errors::InvalidArgument("image width cannot be >= int32 max"));
276     const int32_t height = static_cast<int32>(height_raw);
277     const int32_t width = static_cast<int32>(width_raw);
278 
279     // Ensure that the supplied bounding boxes are sane and convert them to
280     // Rectangles.
281     const Tensor& input_boxes = context->input(1);
282     OP_REQUIRES(context, input_boxes.dims() == 3,
283                 errors::InvalidArgument("input boxes must be 3-dimensional "
284                                         "[batch, num_boxes, coords]: ",
285                                         input_boxes.shape().DebugString()));
286     OP_REQUIRES(context, input_boxes.dim_size(input_boxes.dims() - 1) == 4,
287                 errors::InvalidArgument(
288                     "bounding boxes must have shape [4] or [*, 4], got ",
289                     input_boxes.shape().DebugString()));
290 
291     float min_object_covered_val = 0.0;
292     // `SampleDistortedBoundingBox` op accepts 2 inputs and has
293     // `min_object_covered` as an attribute (handled in the constructor).
294     // `SampleDistortedBoundingBoxV2` and `StatelessSampleDistortedBoundingBox`
295     // ops accept 3+ inputs, including `min_object_covered`.
296     if (context->num_inputs() >= 3) {
297       const Tensor& min_object_covered = context->input(2);
298 
299       OP_REQUIRES(
300           context, TensorShapeUtils::IsScalar(min_object_covered.shape()),
301           errors::InvalidArgument("min_object_covered must be 0-D, got shape ",
302                                   min_object_covered.shape().DebugString()));
303 
304       min_object_covered_val = min_object_covered.scalar<float>()();
305 
306       OP_REQUIRES(
307           context, min_object_covered_val >= 0,
308           errors::InvalidArgument("Min object covered must be non-negative: ",
309                                   min_object_covered_val));
310     } else {
311       min_object_covered_val = min_object_covered_;
312     }
313 
314     std::vector<Rectangle> bounding_boxes;
315     if (input_boxes.NumElements() > 0) {
316       TTypes<float>::ConstMatrix boxes = input_boxes.flat_inner_dims<float>();
317       for (int b = 0; b < boxes.dimension(0); ++b) {
318         for (int i = 0; i < 4; ++i) {
319           OP_REQUIRES(
320               context, boxes(b, i) >= 0.0 && boxes(b, i) <= 1.0,
321               errors::InvalidArgument("All bounding box coordinates must "
322                                       "be in [0.0, 1.0]: ",
323                                       boxes(b, i)));
324         }
325 
326         const int32_t x_min = static_cast<int32>(boxes(b, 1) * width);
327         const int32_t y_min = static_cast<int32>(boxes(b, 0) * height);
328         const int32_t x_max = static_cast<int32>(boxes(b, 3) * width);
329         const int32_t y_max = static_cast<int32>(boxes(b, 2) * height);
330 
331         bounding_boxes.push_back(Rectangle(x_min, y_min, x_max, y_max));
332       }
333     }
334 
335     // Insert the entire image if no bounding boxes are supplied.
336     const Rectangle image_rect(0, 0, width, height);
337     if (bounding_boxes.empty()) {
338       OP_REQUIRES(context, use_image_if_no_bounding_boxes_,
339                   errors::InvalidArgument(
340                       "No bounding boxes provided as input. One must "
341                       "enable use_image_if_no_bounding_boxes if you wish "
342                       "to not provide any bounding boxes."));
343       bounding_boxes.push_back(image_rect);
344     }
345 
346     const float min_sample_area = area_range_[0];
347     const float max_sample_area = area_range_[1];
348     const float min_sample_aspect_ratio = aspect_ratio_range_[0];
349     const float max_sample_aspect_ratio = aspect_ratio_range_[1];
350 
351     auto local_rng = rng;
352     random::SimplePhilox random(&local_rng);
353 
354     Rectangle crop_rect;
355     bool sample_generated = false;
356     for (int i = 0; i < max_attempts_; ++i) {
357       const float sample_aspect_ratio =
358           random.RandFloat() *
359               (max_sample_aspect_ratio - min_sample_aspect_ratio) +
360           min_sample_aspect_ratio;
361 
362       if (GenerateRandomCrop(width, height, min_sample_area, max_sample_area,
363                              sample_aspect_ratio, &random, &crop_rect)) {
364         if (SatisfiesOverlapConstraints(crop_rect, min_object_covered_val,
365                                         bounding_boxes)) {
366           sample_generated = true;
367           break;
368         }
369       }
370     }
371 
372     if (!sample_generated) {
373       crop_rect = image_rect;
374     }
375 
376     // Determine the cropping parameters from the bounding box.
377     const int target_width = crop_rect.max_x_ - crop_rect.min_x_;
378     const int target_height = crop_rect.max_y_ - crop_rect.min_y_;
379 
380     const int offset_width = crop_rect.min_x_;
381     const int offset_height = crop_rect.min_y_;
382 
383     // Ensure that the bounding box fits in the image dimensions.
384     OP_REQUIRES(context, width >= target_width + offset_width,
385                 errors::FailedPrecondition(
386                     "width must be > target_width + offset_width: ", width,
387                     "vs ", target_width, " + ", offset_width));
388     OP_REQUIRES(context, height >= target_height + offset_height,
389                 errors::FailedPrecondition(
390                     "height must be >= target_height: height = ", height, "vs ",
391                     target_height, " + ", offset_height));
392 
393     // Create two vectors, each 3 elements, to provide as arguments to Slice.
394     // See Slice() operation for details.
395     Tensor* begin = nullptr;
396     OP_REQUIRES_OK(context,
397                    context->allocate_output(0, TensorShape({3}), &begin));
398     Tensor* size = nullptr;
399     OP_REQUIRES_OK(context,
400                    context->allocate_output(1, TensorShape({3}), &size));
401     Tensor* bboxes = nullptr;
402     OP_REQUIRES_OK(
403         context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes));
404 
405     typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>());
406     typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>());
407     TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
408 
409     begin_data(0) = T(offset_height);
410     size_data(0) = T(target_height);
411 
412     begin_data(1) = T(offset_width);
413     size_data(1) = T(target_width);
414 
415     bboxes_data(0, 0, 0) =
416         static_cast<float>(crop_rect.min_y_) / static_cast<float>(height);
417     bboxes_data(0, 0, 1) =
418         static_cast<float>(crop_rect.min_x_) / static_cast<float>(width);
419     bboxes_data(0, 0, 2) =
420         static_cast<float>(crop_rect.max_y_) / static_cast<float>(height);
421     bboxes_data(0, 0, 3) =
422         static_cast<float>(crop_rect.max_x_) / static_cast<float>(width);
423 
424     // Retain all of the channels.
425     begin_data(2) = T(0);
426     size_data(2) = T(-1);
427   }
428 
429  protected:
430   int32 max_attempts_;
431   std::vector<float> area_range_;
432   std::vector<float> aspect_ratio_range_;
433   float min_object_covered_;
434   bool use_image_if_no_bounding_boxes_;
435 };
436 
437 template <typename T>
438 class StatefulSampleDistortedBoundingBoxOp
439     : public SampleDistortedBoundingBoxBaseOp<T> {
440  public:
StatefulSampleDistortedBoundingBoxOp(OpKernelConstruction * context)441   explicit StatefulSampleDistortedBoundingBoxOp(OpKernelConstruction* context)
442       : SampleDistortedBoundingBoxBaseOp<T>(context) {
443     OP_REQUIRES_OK(context, generator_.Init(context));
444   }
445 
Compute(OpKernelContext * context)446   void Compute(OpKernelContext* context) override {
447     // Need to reserve samples since `generator_` is shared.
448     this->DoCompute(context,
449                     generator_.ReserveSamples32(4 * this->max_attempts_));
450   }
451 
452  private:
453   GuardedPhiloxRandom generator_;
454 };
455 
456 template <typename T>
457 class StatelessSampleDistortedBoundingBoxOp
458     : public SampleDistortedBoundingBoxBaseOp<T> {
459  public:
StatelessSampleDistortedBoundingBoxOp(OpKernelConstruction * context)460   explicit StatelessSampleDistortedBoundingBoxOp(OpKernelConstruction* context)
461       : SampleDistortedBoundingBoxBaseOp<T>(context) {}
462 
Compute(OpKernelContext * context)463   void Compute(OpKernelContext* context) override {
464     const Tensor& seed_t = context->input(3);
465     OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
466                 errors::InvalidArgument("seed must have shape [2], not ",
467                                         seed_t.shape().DebugString()));
468 
469     // Create and initialize stateless random number generator (rng).
470     // There is no need to `Skip` (or reserve) samples since the scope of this
471     // rng is local.
472     random::PhiloxRandom::Key key;
473     random::PhiloxRandom::ResultType counter;
474     OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
475 
476     this->DoCompute(context, random::PhiloxRandom(counter, key));
477   }
478 };
479 
// Registers the CPU kernels for one element type `type` (the "T" type
// constraint). "SampleDistortedBoundingBox" and "SampleDistortedBoundingBoxV2"
// are both backed by StatefulSampleDistortedBoundingBoxOp, while
// "StatelessSampleDistortedBoundingBox" is backed by
// StatelessSampleDistortedBoundingBoxOp. The macro is expanded once per type
// enumerated by TF_CALL_INTEGRAL_TYPES and undefined right afterwards.
#define REGISTER_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBox")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T"),             \
                          StatefulSampleDistortedBoundingBoxOp<type>) \
  REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBoxV2")        \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T"),             \
                          StatefulSampleDistortedBoundingBoxOp<type>) \
  REGISTER_KERNEL_BUILDER(Name("StatelessSampleDistortedBoundingBox") \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T"),             \
                          StatelessSampleDistortedBoundingBoxOp<type>)

TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
496 
497 }  // namespace tensorflow
498