• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/image/crop_and_resize_op.h"
21 
22 #include <functional>
23 #include <string>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/bounds_check.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_reference.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/determinism.h"
37 #include "tensorflow/core/util/work_sharder.h"
38 
39 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
40 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
41 #include "tensorflow/core/platform/stream_executor.h"
42 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
43 
44 #if GOOGLE_CUDA
45 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
46 using stream_executor::cuda::ScopedActivateExecutorContext;
47 #elif TENSORFLOW_USE_ROCM
48 #include "tensorflow/core/platform/rocm.h"
49 using stream_executor::rocm::ScopedActivateExecutorContext;
50 #endif
51 
52 namespace tensorflow {
53 namespace {
54 
55 typedef Eigen::ThreadPoolDevice CPUDevice;
56 typedef Eigen::GpuDevice GPUDevice;
57 using Callback = std::function<void()>;
58 
ParseAndCheckBoxSizes(const Tensor & boxes,const Tensor & box_index,int * num_boxes)59 static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
60                                            const Tensor& box_index,
61                                            int* num_boxes) {
62   if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
63     *num_boxes = 0;
64     return Status::OK();
65   }
66   // The shape of 'boxes' is [num_boxes, 4].
67   if (boxes.dims() != 2) {
68     return errors::InvalidArgument("boxes must be 2-D",
69                                    boxes.shape().DebugString());
70   }
71   *num_boxes = boxes.dim_size(0);
72   if (boxes.dim_size(1) != 4) {
73     return errors::InvalidArgument("boxes must have 4 columns");
74   }
75   // The shape of 'box_index' is [num_boxes].
76   if (box_index.dims() != 1) {
77     return errors::InvalidArgument("box_index must be 1-D",
78                                    box_index.shape().DebugString());
79   }
80   if (box_index.dim_size(0) != *num_boxes) {
81     return errors::InvalidArgument("box_index has incompatible shape");
82   }
83   return Status::OK();
84 }
85 
86 // Conditionally calls the compute callback if all values in box_index are in
87 // [0, batch_size) then calls done.
88 template <typename Device>
89 inline void RunIfBoxIndexIsValid(
90     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
91     int batch_size, const Callback& compute, const Callback& done);
92 
93 // Specialization of CheckValidBoxIndex for a CPUDevice.
94 template <>
RunIfBoxIndexIsValid(OpKernelContext * context,typename TTypes<int32,1>::ConstTensor box_index,int batch_size,const Callback & compute,const Callback & done)95 inline void RunIfBoxIndexIsValid<CPUDevice>(
96     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
97     int batch_size, const Callback& compute, const Callback& done) {
98   const int num_boxes = box_index.dimension(0);
99   for (int b = 0; b < num_boxes; ++b) {
100     OP_REQUIRES_ASYNC(
101         context, FastBoundsCheck(box_index(b), batch_size),
102         errors::OutOfRange("box_index has values outside [0, batch_size)"),
103         done);
104   }
105   if (compute) {
106     compute();
107   }
108   if (done) {
109     done();
110   }
111 }
112 
113 }  // namespace
114 
115 template <typename Device, typename T>
116 class CropAndResizeOp : public AsyncOpKernel {
117  public:
CropAndResizeOp(OpKernelConstruction * context)118   explicit CropAndResizeOp(OpKernelConstruction* context)
119       : AsyncOpKernel(context) {
120     OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
121     OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
122                 errors::InvalidArgument(
123                     "method must be 'bilinear' or 'nearest'", method_));
124     OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
125                                              &extrapolation_value_));
126   }
127 
ComputeAsync(OpKernelContext * context,DoneCallback done)128   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
129     // The shape of 'image' is [batch_size, image_height, image_width,
130     // channels].
131     const Tensor& image = context->input(0);
132     // The shape of 'boxes' is [num_boxes, 4].
133     const Tensor& boxes = context->input(1);
134     // The shape of 'box_index' is [num_boxes].
135     const Tensor& box_index = context->input(2);
136     // The shape of 'crop_size' is [2].
137     const Tensor& crop_size = context->input(3);
138 
139     // Validate inputs dimensions.
140     OP_REQUIRES_ASYNC(context, image.dims() == 4,
141                       errors::InvalidArgument("input image must be 4-D",
142                                               image.shape().DebugString()),
143                       done);
144     const int batch_size = image.dim_size(0);
145     const int image_height = image.dim_size(1);
146     const int image_width = image.dim_size(2);
147     const int depth = image.dim_size(3);
148     OP_REQUIRES_ASYNC(
149         context, image_height > 0 && image_width > 0,
150         errors::InvalidArgument("image dimensions must be positive"), done);
151     int num_boxes = 0;
152     OP_REQUIRES_OK_ASYNC(
153         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
154 
155     OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
156                       errors::InvalidArgument("crop_size must be 1-D",
157                                               crop_size.shape().DebugString()),
158                       done);
159     OP_REQUIRES_ASYNC(
160         context, crop_size.dim_size(0) == 2,
161         errors::InvalidArgument("crop_size must have two elements",
162                                 crop_size.shape().DebugString()),
163         done);
164 
165     // Copy and validate crop sizes.
166     auto crop_size_vec = crop_size.vec<int32>();
167     const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
168     const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
169     OP_REQUIRES_ASYNC(
170         context, crop_height > 0 && crop_width > 0,
171         errors::InvalidArgument("crop dimensions must be positive"), done);
172 
173     // Allocate output tensor.
174     Tensor* output = nullptr;
175     OP_REQUIRES_OK_ASYNC(
176         context,
177         context->allocate_output(
178             0, TensorShape({num_boxes, crop_height, crop_width, depth}),
179             &output),
180         done);
181 
182     auto compute_callback = [this, context, output]() {
183       const Tensor& image = context->input(0);
184       const Tensor& boxes = context->input(1);
185       const Tensor& box_index = context->input(2);
186       const bool status = functor::CropAndResize<Device, T>()(
187           context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
188           box_index.tensor<int32, 1>(), method_, extrapolation_value_,
189           output->tensor<float, 4>());
190 
191       if (!status) {
192         context->SetStatus(
193             errors::Internal("Failed to launch CropAndResizeKernel."));
194       }
195     };
196 
197     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
198                                  batch_size, std::move(compute_callback),
199                                  std::move(done));
200   }
201 
202  private:
203   float extrapolation_value_;
204   string method_;
205 };
206 
207 // Partial specialization of CropAndResize functor for a CPUDevice.
208 namespace functor {
209 template <typename T>
210 struct CropAndResize<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResize211   bool operator()(OpKernelContext* context,
212                   typename TTypes<T, 4>::ConstTensor image,
213                   typename TTypes<float, 2>::ConstTensor boxes,
214                   typename TTypes<int32, 1>::ConstTensor box_index,
215                   const string& method_name, float extrapolation_value,
216                   typename TTypes<float, 4>::Tensor crops) {
217     const int batch_size = image.dimension(0);
218     const int image_height = image.dimension(1);
219     const int image_width = image.dimension(2);
220 
221     const int num_boxes = crops.dimension(0);
222     const int crop_height = crops.dimension(1);
223     const int crop_width = crops.dimension(2);
224     const int depth = crops.dimension(3);
225 
226     // Since `functor::CropAndResize` operates on float, we first validate
227     // that we don't overflow (since overflow causes undefined behavior which
228     // could result in segfault in this scenario).
229     const Eigen::Tensor<bool, 0, Eigen::RowMajor> only_finite_elements =
230         boxes.isfinite().all();
231     if (!only_finite_elements()) {
232       context->SetStatus(errors::InvalidArgument(
233           "Boxes contains at least one element that is not finite"));
234       return false;
235     }
236 
237     // Sharding across boxes.
238     auto CropAndResizePerBox = [&](int64_t start_box, int64_t limit_box) {
239       for (int b = start_box; b < limit_box; ++b) {
240         const float y1 = boxes(b, 0);
241         const float x1 = boxes(b, 1);
242         const float y2 = boxes(b, 2);
243         const float x2 = boxes(b, 3);
244 
245         const int32_t b_in = box_index(b);
246         if (!FastBoundsCheck(b_in, batch_size)) {
247           continue;
248         }
249 
250         const float height_scale =
251             (crop_height > 1)
252                 ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
253                 : 0;
254         const float width_scale =
255             (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
256                              : 0;
257 
258         for (int y = 0; y < crop_height; ++y) {
259           const float in_y = (crop_height > 1)
260                                  ? y1 * (image_height - 1) + y * height_scale
261                                  : 0.5 * (y1 + y2) * (image_height - 1);
262           if (in_y < 0 || in_y > image_height - 1) {
263             for (int x = 0; x < crop_width; ++x) {
264               for (int d = 0; d < depth; ++d) {
265                 crops(b, y, x, d) = extrapolation_value;
266               }
267             }
268             continue;
269           }
270           if (method_name == "bilinear") {
271             const int top_y_index = floorf(in_y);
272             const int bottom_y_index = ceilf(in_y);
273             const float y_lerp = in_y - top_y_index;
274 
275             for (int x = 0; x < crop_width; ++x) {
276               const float in_x = (crop_width > 1)
277                                      ? x1 * (image_width - 1) + x * width_scale
278                                      : 0.5 * (x1 + x2) * (image_width - 1);
279               if (in_x < 0 || in_x > image_width - 1) {
280                 for (int d = 0; d < depth; ++d) {
281                   crops(b, y, x, d) = extrapolation_value;
282                 }
283                 continue;
284               }
285               const int left_x_index = floorf(in_x);
286               const int right_x_index = ceilf(in_x);
287               const float x_lerp = in_x - left_x_index;
288 
289               for (int d = 0; d < depth; ++d) {
290                 const float top_left(static_cast<float>(
291                     image(b_in, top_y_index, left_x_index, d)));
292                 const float top_right(static_cast<float>(
293                     image(b_in, top_y_index, right_x_index, d)));
294                 const float bottom_left(static_cast<float>(
295                     image(b_in, bottom_y_index, left_x_index, d)));
296                 const float bottom_right(static_cast<float>(
297                     image(b_in, bottom_y_index, right_x_index, d)));
298                 const float top = top_left + (top_right - top_left) * x_lerp;
299                 const float bottom =
300                     bottom_left + (bottom_right - bottom_left) * x_lerp;
301                 crops(b, y, x, d) = top + (bottom - top) * y_lerp;
302               }
303             }
304           } else {  // method == "nearest"
305             for (int x = 0; x < crop_width; ++x) {
306               const float in_x = (crop_width > 1)
307                                      ? x1 * (image_width - 1) + x * width_scale
308                                      : 0.5 * (x1 + x2) * (image_width - 1);
309               if (in_x < 0 || in_x > image_width - 1) {
310                 for (int d = 0; d < depth; ++d) {
311                   crops(b, y, x, d) = extrapolation_value;
312                 }
313                 continue;
314               }
315               const int closest_x_index = roundf(in_x);
316               const int closest_y_index = roundf(in_y);
317               for (int d = 0; d < depth; ++d) {
318                 crops(b, y, x, d) = static_cast<float>(
319                     image(b_in, closest_y_index, closest_x_index, d));
320               }
321             }
322           }
323         }
324       }
325     };
326 
327     // A rough estimation of the cost for each cropped box.
328     double cost_per_pixel =
329         depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
330                  Eigen::TensorOpCost::MulCost<float>() * 3 +
331                  Eigen::TensorOpCost::CastCost<T, float>() * 4) +
332         (Eigen::TensorOpCost::AddCost<float>() * 2 +
333          Eigen::TensorOpCost::AddCost<float>() * 3);
334     if (method_name == "nearest") {
335       cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
336                        Eigen::TensorOpCost::AddCost<float>() * 4 +
337                        Eigen::TensorOpCost::MulCost<float>() * 4;
338     }
339     const double cost_per_box = crop_height * crop_width * cost_per_pixel;
340 
341     const DeviceBase::CpuWorkerThreads& worker_threads =
342         *(context->device()->tensorflow_cpu_worker_threads());
343     Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
344           cost_per_box, CropAndResizePerBox);
345 
346     return true;
347   }
348 };
349 
350 }  // namespace functor
351 
352 template <typename Device, typename T>
353 class CropAndResizeGradImageOp : public AsyncOpKernel {
354  public:
CropAndResizeGradImageOp(OpKernelConstruction * context)355   explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
356       : AsyncOpKernel(context) {
357     OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
358     OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
359                 errors::InvalidArgument(
360                     "method must be 'bilinear' or 'nearest'", method_));
361   }
362 
ComputeAsync(OpKernelContext * context,DoneCallback done)363   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
364     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
365     const Tensor& grads = context->input(0);
366     // The shape of 'boxes' is [num_boxes, 4].
367     const Tensor& boxes = context->input(1);
368     // The shape of 'box_index' is [num_boxes].
369     const Tensor& box_index = context->input(2);
370     // The shape of 'image_size' is [4].
371     const Tensor& image_size = context->input(3);
372 
373     // Validate input shapes.
374     OP_REQUIRES_ASYNC(context, grads.dims() == 4,
375                       errors::InvalidArgument("grads image must be 4-D",
376                                               grads.shape().DebugString()),
377                       done);
378     const int crop_height = grads.dim_size(1);
379     const int crop_width = grads.dim_size(2);
380     OP_REQUIRES_ASYNC(
381         context, crop_height > 0 && crop_width > 0,
382         errors::InvalidArgument("grads dimensions must be positive"), done);
383     int num_boxes = 0;
384     OP_REQUIRES_OK_ASYNC(
385         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
386     OP_REQUIRES_ASYNC(
387         context, grads.dim_size(0) == num_boxes,
388         errors::InvalidArgument("boxes and grads have incompatible shape"),
389         done);
390 
391     OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
392                       errors::InvalidArgument("image_size must be 1-D",
393                                               image_size.shape().DebugString()),
394                       done);
395     OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
396                       errors::InvalidArgument("image_size must have 4 elements",
397                                               image_size.shape().DebugString()),
398                       done);
399     auto image_size_vec = image_size.vec<int32>();
400     const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
401     const int image_height = internal::SubtleMustCopy(image_size_vec(1));
402     const int image_width = internal::SubtleMustCopy(image_size_vec(2));
403     const int depth = internal::SubtleMustCopy(image_size_vec(3));
404     OP_REQUIRES_ASYNC(
405         context, image_height > 0 && image_width > 0,
406         errors::InvalidArgument("image dimensions must be positive"), done);
407     OP_REQUIRES_ASYNC(
408         context, grads.dim_size(3) == depth,
409         errors::InvalidArgument("image_size and grads are incompatible"), done);
410 
411     if (std::is_same<Device, GPUDevice>::value) {
412       OP_REQUIRES_ASYNC(
413           context, !OpDeterminismRequired(),
414           errors::Unimplemented(
415               "Deterministic GPU implementation of CropAndResizeBackpropImage"
416               " not available."),
417           done);
418     }
419 
420     // Allocate output tensor.
421     Tensor* output = nullptr;
422     OP_REQUIRES_OK_ASYNC(
423         context,
424         context->allocate_output(
425             0, TensorShape({batch_size, image_height, image_width, depth}),
426             &output),
427         done);
428 
429     auto compute_callback = [this, context, output]() {
430       const Tensor& grads = context->input(0);
431       const Tensor& boxes = context->input(1);
432       const Tensor& box_index = context->input(2);
433       const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
434           context, grads.tensor<float, 4>(), boxes.tensor<float, 2>(),
435           box_index.tensor<int32, 1>(), output->tensor<T, 4>(), method_);
436 
437       if (!status) {
438         context->SetStatus(errors::Internal(
439             "Failed to launch CropAndResizeBackpropImage kernel."));
440       }
441     };
442 
443     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
444                                  batch_size, std::move(compute_callback),
445                                  std::move(done));
446   }
447 
448  private:
449   string method_;
450 };
451 
452 // Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
453 namespace functor {
454 template <typename T>
455 struct CropAndResizeBackpropImage<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResizeBackpropImage456   bool operator()(const OpKernelContext* context,
457                   typename TTypes<float, 4>::ConstTensor grads,
458                   typename TTypes<float, 2>::ConstTensor boxes,
459                   typename TTypes<int32, 1>::ConstTensor box_index,
460                   typename TTypes<T, 4>::Tensor grads_image,
461                   const string& method_name) {
462     const int batch_size = grads_image.dimension(0);
463     const int image_height = grads_image.dimension(1);
464     const int image_width = grads_image.dimension(2);
465 
466     const int num_boxes = grads.dimension(0);
467     const int crop_height = grads.dimension(1);
468     const int crop_width = grads.dimension(2);
469     const int depth = grads.dimension(3);
470 
471     grads_image.setZero();
472 
473     auto CropAndResizeBackImgPerBox = [&](int64_t start_box,
474                                           int64_t limit_box) {
475       for (int b = start_box; b < limit_box; ++b) {
476         const float y1 = boxes(b, 0);
477         const float x1 = boxes(b, 1);
478         const float y2 = boxes(b, 2);
479         const float x2 = boxes(b, 3);
480 
481         const int32_t b_in = box_index(b);
482         if (!FastBoundsCheck(b_in, batch_size)) {
483           continue;
484         }
485 
486         const float height_scale =
487             (crop_height > 1)
488                 ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
489                 : 0;
490         const float width_scale =
491             (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
492                              : 0;
493 
494         for (int y = 0; y < crop_height; ++y) {
495           const float in_y = (crop_height > 1)
496                                  ? y1 * (image_height - 1) + y * height_scale
497                                  : 0.5 * (y1 + y2) * (image_height - 1);
498           if (in_y < 0 || in_y > image_height - 1) {
499             continue;
500           }
501           const int top_y_index = floorf(in_y);
502           const int bottom_y_index = ceilf(in_y);
503           const float y_lerp = in_y - top_y_index;
504 
505           for (int x = 0; x < crop_width; ++x) {
506             const float in_x = (crop_width > 1)
507                                    ? x1 * (image_width - 1) + x * width_scale
508                                    : 0.5 * (x1 + x2) * (image_width - 1);
509             if (in_x < 0 || in_x > image_width - 1) {
510               continue;
511             }
512 
513             if (method_name == "bilinear") {
514               const int left_x_index = floorf(in_x);
515               const int right_x_index = ceilf(in_x);
516               const float x_lerp = in_x - left_x_index;
517 
518               for (int d = 0; d < depth; ++d) {
519                 const float dtop = (1 - y_lerp) * grads(b, y, x, d);
520                 grads_image(b_in, top_y_index, left_x_index, d) +=
521                     static_cast<T>((1 - x_lerp) * dtop);
522                 grads_image(b_in, top_y_index, right_x_index, d) +=
523                     static_cast<T>(x_lerp * dtop);
524                 const float dbottom = y_lerp * grads(b, y, x, d);
525                 grads_image(b_in, bottom_y_index, left_x_index, d) +=
526                     static_cast<T>((1 - x_lerp) * dbottom);
527                 grads_image(b_in, bottom_y_index, right_x_index, d) +=
528                     static_cast<T>(x_lerp * dbottom);
529               }
530             } else {  // method_name == "nearest"
531               for (int d = 0; d < depth; ++d) {
532                 int closest_x_index = roundf(in_x);
533                 int closest_y_index = roundf(in_y);
534                 grads_image(b_in, closest_y_index, closest_x_index, d) +=
535                     static_cast<T>(grads(b, y, x, d));
536               }
537             }
538           }
539         }
540       }
541     };
542 
543     // A rough estimation of the cost for each cropped box.
544     // Including calculation cost in the depth loop and pixel loop.
545     const double cost_per_pixel =
546         (method_name == "bilinear"
547              ? depth * (Eigen::TensorOpCost::AddCost<float>() * 7 +
548                         Eigen::TensorOpCost::MulCost<float>() * 6 +
549                         Eigen::TensorOpCost::CastCost<T, float>() * 4) +
550                    Eigen::TensorOpCost::AddCost<float>() * 4
551              : depth * (Eigen::TensorOpCost::AddCost<float>() +
552                         Eigen::TensorOpCost::CastCost<T, float>()) +
553                    Eigen::TensorOpCost::AddCost<float>() * 3);
554 
555     const double cost_per_box = crop_height * crop_width * cost_per_pixel;
556 
557     const DeviceBase::CpuWorkerThreads& worker_threads =
558         *(context->device()->tensorflow_cpu_worker_threads());
559 
560     // Sharding introduces nondeterminism when the gradients associated with
561     // more than two crops backprop into the same element in the source image.
562     int max_threads = OpDeterminismRequired() ? 1 : worker_threads.num_threads;
563 
564     Shard(max_threads, worker_threads.workers, num_boxes, cost_per_box,
565           CropAndResizeBackImgPerBox);
566 
567     return true;
568   }
569 };
570 
571 }  // namespace functor
572 
573 template <typename Device, typename T>
574 class CropAndResizeGradBoxesOp : public AsyncOpKernel {
575  public:
CropAndResizeGradBoxesOp(OpKernelConstruction * context)576   explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
577       : AsyncOpKernel(context) {
578     string method;
579     OP_REQUIRES_OK(context, context->GetAttr("method", &method));
580     OP_REQUIRES(context, method == "bilinear",
581                 errors::InvalidArgument("method must be 'bilinear'", method));
582   }
583 
ComputeAsync(OpKernelContext * context,DoneCallback done)584   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
585     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
586     const Tensor& grads = context->input(0);
587     // The shape of 'boxes' is [num_boxes, 4].
588     const Tensor& boxes = context->input(2);
589     // The shape of 'box_index' is [num_boxes].
590     const Tensor& box_index = context->input(3);
591     // The shape of 'image' is [batch_size, image_height, image_width, depth].
592     const Tensor& image = context->input(1);
593 
594     // Validate input shapes.
595     OP_REQUIRES_ASYNC(context, grads.dims() == 4,
596                       errors::InvalidArgument("grads image must be 4-D",
597                                               grads.shape().DebugString()),
598                       done);
599     const int crop_height = grads.dim_size(1);
600     const int crop_width = grads.dim_size(2);
601     const int depth = grads.dim_size(3);
602     OP_REQUIRES_ASYNC(
603         context, crop_height > 0 && crop_width > 0,
604         errors::InvalidArgument("grads dimensions must be positive"), done);
605 
606     OP_REQUIRES_ASYNC(context, image.dims() == 4,
607                       errors::InvalidArgument("input image must be 4-D",
608                                               image.shape().DebugString()),
609                       done);
610     const int batch_size = image.dim_size(0);
611     const int image_height = image.dim_size(1);
612     const int image_width = image.dim_size(2);
613     OP_REQUIRES_ASYNC(
614         context, image_height > 0 && image_width > 0,
615         errors::InvalidArgument("image dimensions must be positive"), done);
616     OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
617                       errors::InvalidArgument("image, grads depth differ"),
618                       done);
619 
620     int num_boxes = 0;
621     OP_REQUIRES_OK_ASYNC(
622         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
623 
624     OP_REQUIRES_ASYNC(
625         context, grads.dim_size(0) == num_boxes,
626         errors::InvalidArgument("boxes and grads have incompatible shape"),
627         done);
628 
629     if (std::is_same<Device, GPUDevice>::value) {
630       OP_REQUIRES_ASYNC(
631           context, !OpDeterminismRequired(),
632           errors::Unimplemented(
633               "Deterministic GPU implementation of CropAndResizeBackpropBoxes"
634               " not available."),
635           done);
636     }
637 
638     // Allocate output tensor.
639     Tensor* output = nullptr;
640     OP_REQUIRES_OK_ASYNC(
641         context,
642         context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
643         done);
644 
645     auto compute_callback = [context, output]() {
646       const Tensor& grads = context->input(0);
647       const Tensor& image = context->input(1);
648       const Tensor& boxes = context->input(2);
649       const Tensor& box_index = context->input(3);
650       const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
651           context->eigen_device<Device>(), grads.tensor<float, 4>(),
652           image.tensor<T, 4>(), boxes.tensor<float, 2>(),
653           box_index.tensor<int32, 1>(), output->tensor<float, 2>());
654       if (!status) {
655         context->SetStatus(errors::Internal(
656             "Failed to launch CropAndResizeBackpropBoxes kernel."));
657       }
658     };
659 
660     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
661                                  batch_size, std::move(compute_callback),
662                                  std::move(done));
663   }
664 };
665 
666 // Partial specialization of CropAndResizeBackpropBoxes functor for a CPUDevice.
667 namespace functor {
668 template <typename T>
669 struct CropAndResizeBackpropBoxes<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResizeBackpropBoxes670   bool operator()(const CPUDevice& d,
671                   typename TTypes<float, 4>::ConstTensor grads,
672                   typename TTypes<T, 4>::ConstTensor image,
673                   typename TTypes<float, 2>::ConstTensor boxes,
674                   typename TTypes<int32, 1>::ConstTensor box_index,
675                   typename TTypes<float, 2>::Tensor grads_boxes) {
676     const int batch_size = image.dimension(0);
677     const int image_height = image.dimension(1);
678     const int image_width = image.dimension(2);
679 
680     const int num_boxes = grads.dimension(0);
681     const int crop_height = grads.dimension(1);
682     const int crop_width = grads.dimension(2);
683     const int depth = grads.dimension(3);
684 
685     grads_boxes.setZero();
686 
687     for (int b = 0; b < num_boxes; ++b) {
688       const float y1 = boxes(b, 0);
689       const float x1 = boxes(b, 1);
690       const float y2 = boxes(b, 2);
691       const float x2 = boxes(b, 3);
692 
693       const int32_t b_in = box_index(b);
694       if (!FastBoundsCheck(b_in, batch_size)) {
695         continue;
696       }
697 
698       const float height_ratio =
699           (crop_height > 1)
700               ? static_cast<float>(image_height - 1) / (crop_height - 1)
701               : 0;
702       const float width_ratio =
703           (crop_width > 1)
704               ? static_cast<float>(image_width - 1) / (crop_width - 1)
705               : 0;
706 
707       const float height_scale =
708           (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
709       const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;
710 
711       for (int y = 0; y < crop_height; ++y) {
712         const float in_y = (crop_height > 1)
713                                ? y1 * (image_height - 1) + y * height_scale
714                                : 0.5 * (y1 + y2) * (image_height - 1);
715         if (in_y < 0 || in_y > image_height - 1) {
716           continue;
717         }
718         const int top_y_index = floorf(in_y);
719         const int bottom_y_index = ceilf(in_y);
720         const float y_lerp = in_y - top_y_index;
721 
722         for (int x = 0; x < crop_width; ++x) {
723           const float in_x = (crop_width > 1)
724                                  ? x1 * (image_width - 1) + x * width_scale
725                                  : 0.5 * (x1 + x2) * (image_width - 1);
726           if (in_x < 0 || in_x > image_width - 1) {
727             continue;
728           }
729           const int left_x_index = floorf(in_x);
730           const int right_x_index = ceilf(in_x);
731           const float x_lerp = in_x - left_x_index;
732 
733           for (int d = 0; d < depth; ++d) {
734             const float top_left(
735                 static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
736             const float top_right(
737                 static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
738             const float bottom_left(static_cast<float>(
739                 image(b_in, bottom_y_index, left_x_index, d)));
740             const float bottom_right(static_cast<float>(
741                 image(b_in, bottom_y_index, right_x_index, d)));
742             // Compute the image gradient.
743             float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
744                                  x_lerp * (bottom_right - top_right);
745             float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
746                                  y_lerp * (bottom_right - bottom_left);
747             // Modulate the image gradient with the incoming gradient.
748             const float top_grad = grads(b, y, x, d);
749             image_grad_y *= top_grad;
750             image_grad_x *= top_grad;
751             // dy1, dy2
752             if (crop_height > 1) {
753               grads_boxes(b, 0) +=
754                   image_grad_y * (image_height - 1 - y * height_ratio);
755               grads_boxes(b, 2) += image_grad_y * (y * height_ratio);
756             } else {
757               grads_boxes(b, 0) += image_grad_y * 0.5 * (image_height - 1);
758               grads_boxes(b, 2) += image_grad_y * 0.5 * (image_height - 1);
759             }
760             // dx1, dx2
761             if (crop_width > 1) {
762               grads_boxes(b, 1) +=
763                   image_grad_x * (image_width - 1 - x * width_ratio);
764               grads_boxes(b, 3) += image_grad_x * (x * width_ratio);
765             } else {
766               grads_boxes(b, 1) += image_grad_x * 0.5 * (image_width - 1);
767               grads_boxes(b, 3) += image_grad_x * 0.5 * (image_width - 1);
768             }
769           }
770         }
771       }
772     }
773     return true;
774   }
775 };
776 
777 }  // namespace functor
778 
779 #define REGISTER_KERNEL(T)                                \
780   REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
781                               .Device(DEVICE_CPU)         \
782                               .TypeConstraint<T>("T")     \
783                               .HostMemory("crop_size"),   \
784                           CropAndResizeOp<CPUDevice, T>); \
785                                                           \
786   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
787                               .Device(DEVICE_CPU)         \
788                               .TypeConstraint<T>("T"),    \
789                           CropAndResizeGradBoxesOp<CPUDevice, T>);
790 
791 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
792 
793 #undef REGISTER_KERNEL
794 
795 #define REGISTER_KERNEL(T)                               \
796   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
797                               .Device(DEVICE_CPU)        \
798                               .TypeConstraint<T>("T")    \
799                               .HostMemory("image_size"), \
800                           CropAndResizeGradImageOp<CPUDevice, T>);
801 
802 TF_CALL_half(REGISTER_KERNEL);
803 TF_CALL_float(REGISTER_KERNEL);
804 TF_CALL_double(REGISTER_KERNEL);
805 
806 #undef REGISTER_KERNEL
807 
808 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
809 
810 // Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
811 namespace functor {
812 template <>
813 void CheckValidBoxIndexHelper<GPUDevice>::operator()(
814     const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
815     int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
816 extern template struct CheckValidBoxIndexHelper<GPUDevice>;
817 }  // namespace functor
818 
819 namespace {
820 
821 // Specialization of CheckValidBoxIndex for a GPUDevice.
822 template <>
RunIfBoxIndexIsValid(OpKernelContext * context,typename TTypes<int32,1>::ConstTensor box_index,int batch_size,const Callback & compute,const Callback & done)823 inline void RunIfBoxIndexIsValid<GPUDevice>(
824     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
825     int batch_size, const Callback& compute, const Callback& done) {
826   const int num_boxes = box_index.dimension(0);
827   if (num_boxes == 0) {
828     compute();
829     done();
830     return;
831   }
832 
833   Tensor isvalid_dev_tensor;
834   OP_REQUIRES_OK_ASYNC(
835       context,
836       context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
837                              &isvalid_dev_tensor),
838       done);
839   typename TTypes<bool, 0>::Tensor isvalid_dev =
840       isvalid_dev_tensor.tensor<bool, 0>();
841 
842   // Run the actual box check on the device.
843   functor::CheckValidBoxIndexHelper<GPUDevice>()(
844       context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
845 
846   // Copy the result back to the host.
847   auto* stream = context->op_device_context()->stream();
848   OP_REQUIRES_ASYNC(context, stream,
849                     errors::Internal("No GPU stream available."), done);
850   Tensor isvalid_host_tensor;
851   // Use pinned host memory on the host to avoid unnecessary
852   // synchronization.
853   AllocatorAttributes alloc_attr;
854   alloc_attr.set_on_host(true);
855   alloc_attr.set_gpu_compatible(true);
856   OP_REQUIRES_OK_ASYNC(
857       context,
858       context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
859                              &isvalid_host_tensor, alloc_attr),
860       done);
861   se::DeviceMemoryBase wrapped(isvalid_dev.data(), sizeof(bool));
862   const bool status =
863       stream
864           ->ThenMemcpy(
865               isvalid_host_tensor.scalar<bool>().data() /* destination */,
866               wrapped /* source */, sizeof(bool))
867           .ok();
868   OP_REQUIRES_ASYNC(
869       context, status,
870       errors::Internal("Failed to launch copy of isvalid from device to host."),
871       done);
872 
873   // We capture both temporary tensors to prevent them from being deallocated
874   // when ComputeAsync returns and before the closure runs.
875   TensorReference isvalid_dev_ref(isvalid_dev_tensor);
876   auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref,
877                            compute, done]() {
878     auto stream = context->op_device_context()->stream();
879     ScopedActivateExecutorContext scoped_activation{stream->parent()};
880     const bool isvalid = isvalid_host_tensor.scalar<bool>()();
881     isvalid_dev_ref.Unref();
882     OP_REQUIRES_ASYNC(
883         context, isvalid,
884         errors::OutOfRange("box_index has values outside [0, batch_size)"),
885         done);
886     compute();
887     done();
888   };
889 
890   context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
891       stream, wrapped_callback);
892 }
893 
894 }  // namespace
895 
896 #define REGISTER_KERNEL(T)                                         \
897   REGISTER_KERNEL_BUILDER(Name("CropAndResize")                    \
898                               .Device(DEVICE_GPU)                  \
899                               .TypeConstraint<T>("T")              \
900                               .HostMemory("crop_size"),            \
901                           CropAndResizeOp<GPUDevice, T>);          \
902                                                                    \
903   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage")           \
904                               .Device(DEVICE_GPU)                  \
905                               .TypeConstraint<T>("T")              \
906                               .HostMemory("image_size"),           \
907                           CropAndResizeGradImageOp<GPUDevice, T>); \
908                                                                    \
909   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")           \
910                               .Device(DEVICE_GPU)                  \
911                               .TypeConstraint<T>("T"),             \
912                           CropAndResizeGradBoxesOp<GPUDevice, T>);
913 
914 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
915 
916 #undef REGISTER_KERNEL
917 
918 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
919 
920 }  // namespace tensorflow
921