/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

34 // Converts 'input' from RGB format to HSV format.
35 // 'shape' is the shape of the red/green/blue tensors.
RGBToHSV(XlaOpKernelContext * ctx,xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & rgb,DataType dtype,const TensorShape & shape)36 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
37                                    const std::array<xla::XlaOp, 3>& rgb,
38                                    DataType dtype, const TensorShape& shape) {
39   auto zero = XlaHelpers::Zero(b, dtype);
40   auto one = XlaHelpers::One(b, dtype);
41 
42   auto red = rgb[0];
43   auto green = rgb[1];
44   auto blue = rgb[2];
45   auto value = xla::Max(xla::Max(red, green), blue);
46   auto minimum = xla::Min(xla::Min(red, green), blue);
47   auto range = xla::Sub(value, minimum);
48 
49   auto zeros = xla::Broadcast(zero, shape.dim_sizes());
50   auto saturation =
51       xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
52 
53   auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
54 
55   auto hue =
56       xla::Select(xla::Eq(green, value),
57                   xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
58                            XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
59                   xla::Add(xla::Mul(norm, xla::Sub(red, green)),
60                            XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
61   hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
62                     hue);
63   hue = xla::Select(xla::Gt(range, zero), hue, zeros);
64   hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
65   return {hue, saturation, value};
66 }
67 
68 // Converts 'input' from HSV format to RGB format.
HSVToRGB(xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & hsv,DataType dtype)69 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
70                                    const std::array<xla::XlaOp, 3>& hsv,
71                                    DataType dtype) {
72   xla::XlaOp hue = hsv[0];
73   xla::XlaOp saturation = hsv[1];
74   xla::XlaOp value = hsv[2];
75   auto zero = XlaHelpers::Zero(b, dtype);
76   auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
77   auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
78   auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
79   auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
80   auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
81 
82   auto dh = xla::Mul(hue, six);
83   auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
84   auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
85   auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
86   auto one_minus_s = xla::Sub(one, saturation);
87 
88   auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
89   auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
90   auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
91   return {red, green, blue};
92 }
93 
94 class RGBToHSVOp : public XlaOpKernel {
95  public:
RGBToHSVOp(OpKernelConstruction * context)96   explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
97 
Compile(XlaOpKernelContext * context)98   void Compile(XlaOpKernelContext* context) override {
99     const TensorShape input_shape = context->InputShape(0);
100     OP_REQUIRES(context, input_shape.dims() >= 1,
101                 errors::InvalidArgument("input must be at least 1D",
102                                         input_shape.DebugString()));
103     int channel_dim = input_shape.dims() - 1;
104     int64 channels = input_shape.dim_size(channel_dim);
105     OP_REQUIRES(
106         context, channels == 3,
107         errors::FailedPrecondition("input must have 3 channels but input has ",
108                                    channels, " channels."));
109 
110     xla::XlaBuilder* b = context->builder();
111     xla::XlaOp input = context->Input(0);
112 
113     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
114                                      /*limit_index=*/1, /*stride=*/1,
115                                      /*dimno=*/channel_dim);
116     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
117                                        /*limit_index=*/2, /*stride=*/1,
118                                        /*dimno=*/channel_dim);
119     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
120                                       /*limit_index=*/3, /*stride=*/1,
121                                       /*dimno=*/channel_dim);
122     TensorShape channel_shape = input_shape;
123     channel_shape.set_dim(channel_dim, 1);
124     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
125                         channel_shape);
126 
127     context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
128   }
129 };
130 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
131 
132 class HSVToRGBOp : public XlaOpKernel {
133  public:
HSVToRGBOp(OpKernelConstruction * context)134   explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
135 
Compile(XlaOpKernelContext * context)136   void Compile(XlaOpKernelContext* context) override {
137     const TensorShape input_shape = context->InputShape(0);
138     OP_REQUIRES(context, input_shape.dims() >= 1,
139                 errors::InvalidArgument("input must be at least 1D",
140                                         input_shape.DebugString()));
141     int channel_dim = input_shape.dims() - 1;
142     int64 channels = input_shape.dim_size(channel_dim);
143     OP_REQUIRES(
144         context, channels == 3,
145         errors::FailedPrecondition("input must have 3 channels but input has ",
146                                    channels, " channels."));
147 
148     xla::XlaBuilder* b = context->builder();
149     xla::XlaOp input = context->Input(0);
150     xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
151                                      /*limit_index=*/1, /*stride=*/1,
152                                      /*dimno=*/channel_dim);
153     xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
154                                             /*limit_index=*/2, /*stride=*/1,
155                                             /*dimno=*/channel_dim);
156     xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
157                                        /*limit_index=*/3, /*stride=*/1,
158                                        /*dimno=*/channel_dim);
159 
160     auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
161                         context->input_type(0));
162 
163     context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
164   }
165 };
166 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
167 
// XLA kernel for the AdjustContrastv2 op. Blends each pixel with the
// per-channel mean over the spatial (height, width) dimensions:
//   output = factor * input + (1 - factor) * mean
class AdjustContrastOpV2 : public XlaOpKernel {
 public:
  explicit AdjustContrastOpV2(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& factor_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape",
                                        input_shape.DebugString()));
    // Input layout is [..., height, width, channels].
    int height_dim = input_shape.dims() - 3;
    int width_dim = input_shape.dims() - 2;
    int channel_dim = input_shape.dims() - 1;
    const int64 height = input_shape.dim_size(height_dim);
    const int64 width = input_shape.dim_size(width_dim);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor_shape.DebugString()));

    xla::XlaBuilder* b = context->builder();
    DataType type = context->input_type(0);

    xla::XlaOp input = context->Input(0);
    xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);

    // Sum over height and width in a wider accumulation type (to limit
    // rounding error for low-precision inputs), then divide by the pixel
    // count to obtain the per-channel spatial mean.
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                              *context->GetOrCreateAdd(accumulation_type),
                              {height_dim, width_dim});

    auto output = xla::Div(
        reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
    output = XlaHelpers::ConvertElementType(output, type);

    // The reduced mean has the batch dims plus the channel dim. To add it
    // back against the full input, map its dimensions to [0, ..., dims-4]
    // (the batch dims) followed by channel_dim — skipping height and width.
    std::vector<int64> broadcast_dims(input_shape.dims() - 2);
    std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    broadcast_dims.back() = channel_dim;
    output =
        xla::Add(xla::Mul(input, factor),
                 xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
                 broadcast_dims);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);

217 class AdjustSaturationOp : public XlaOpKernel {
218  public:
AdjustSaturationOp(OpKernelConstruction * context)219   explicit AdjustSaturationOp(OpKernelConstruction* context)
220       : XlaOpKernel(context) {}
221 
Compile(XlaOpKernelContext * context)222   void Compile(XlaOpKernelContext* context) override {
223     const TensorShape& input_shape = context->InputShape(0);
224     const TensorShape& scale_shape = context->InputShape(1);
225     OP_REQUIRES(context, input_shape.dims() >= 3,
226                 errors::InvalidArgument("input must be at least 3-D, got shape",
227                                         input_shape.DebugString()));
228     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
229                 errors::InvalidArgument("scale must be scalar: ",
230                                         scale_shape.DebugString()));
231     const int channel_dim = input_shape.dims() - 1;
232     const int64 channels = input_shape.dim_size(channel_dim);
233     OP_REQUIRES(
234         context, channels == 3,
235         errors::InvalidArgument("input must have 3 channels but instead has ",
236                                 channels, " channels."));
237 
238     xla::XlaBuilder* b = context->builder();
239     xla::XlaOp input =
240         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
241     xla::XlaOp scale =
242         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
243 
244     DataType type = context->input_type(0);
245 
246     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
247                                      /*limit_index=*/1, /*stride=*/1,
248                                      /*dimno=*/channel_dim);
249     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
250                                        /*limit_index=*/2, /*stride=*/1,
251                                        /*dimno=*/channel_dim);
252     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
253                                       /*limit_index=*/3, /*stride=*/1,
254                                       /*dimno=*/channel_dim);
255     TensorShape channel_shape = input_shape;
256     channel_shape.set_dim(channel_dim, 1);
257     auto hsv =
258         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
259 
260     hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
261                         XlaHelpers::One(b, DT_FLOAT));
262 
263     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
264 
265     auto output = XlaHelpers::ConvertElementType(
266         xla::ConcatInDim(b, rgb, channel_dim), type);
267     context->SetOutput(0, output);
268   }
269 };
270 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
271 
272 class AdjustHueOp : public XlaOpKernel {
273  public:
AdjustHueOp(OpKernelConstruction * context)274   explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
275 
Compile(XlaOpKernelContext * context)276   void Compile(XlaOpKernelContext* context) override {
277     const TensorShape& input_shape = context->InputShape(0);
278     const TensorShape& delta_shape = context->InputShape(1);
279     OP_REQUIRES(context, input_shape.dims() >= 3,
280                 errors::InvalidArgument("input must be at least 3-D, got shape",
281                                         input_shape.DebugString()));
282     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
283                 errors::InvalidArgument("delta must be scalar: ",
284                                         delta_shape.DebugString()));
285     const int channel_dim = input_shape.dims() - 1;
286     const int64 channels = input_shape.dim_size(channel_dim);
287     OP_REQUIRES(
288         context, channels == 3,
289         errors::InvalidArgument("input must have 3 channels but instead has ",
290                                 channels, " channels."));
291 
292     xla::XlaBuilder* b = context->builder();
293     xla::XlaOp input =
294         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
295     xla::XlaOp delta =
296         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
297 
298     DataType type = context->input_type(0);
299 
300     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
301                                      /*limit_index=*/1, /*stride=*/1,
302                                      /*dimno=*/channel_dim);
303     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
304                                        /*limit_index=*/2, /*stride=*/1,
305                                        /*dimno=*/channel_dim);
306     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
307                                       /*limit_index=*/3, /*stride=*/1,
308                                       /*dimno=*/channel_dim);
309     TensorShape channel_shape = input_shape;
310     channel_shape.set_dim(channel_dim, 1);
311     auto hsv =
312         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
313 
314     auto zero = XlaHelpers::Zero(b, DT_FLOAT);
315     auto one = XlaHelpers::One(b, DT_FLOAT);
316 
317     auto& hue = hsv[0];
318     hue = xla::Rem(xla::Add(hsv[0], delta), one);
319     hue =
320         xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
321 
322     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
323 
324     auto output = XlaHelpers::ConvertElementType(
325         xla::ConcatInDim(b, rgb, channel_dim), type);
326     context->SetOutput(0, output);
327   }
328 };
329 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
330 
331 struct WhileCondFn {
332   const int64 num_boxes;
333   const int64 output_size;
334 
WhileCondFntensorflow::__anon4034b0330111::WhileCondFn335   explicit WhileCondFn(int64 num_boxes, int64 output_size)
336       : num_boxes(num_boxes), output_size(output_size) {}
337 
operator ()tensorflow::__anon4034b0330111::WhileCondFn338   xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
339                                        xla::XlaBuilder* cond_builder) const {
340     xla::XlaOp row_idx = values[0];
341     xla::XlaOp row_in_bounds =
342         xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
343     xla::XlaOp num_outputs_so_far = values[1];
344     xla::XlaOp results_not_full = xla::Lt(
345         num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
346     return xla::And(row_in_bounds, results_not_full);
347   }
348 };
349 
// Process the boxes one-by-one using the iou matrix mask.
// This implementation uses a correct, but greedy, sequential algorithm
// to ensure that suppressed boxes cannot themselves suppress other
// boxes.
struct SuppressBodyFn {
  const int64 num_boxes;

  explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {}

  // Loop body. State carried between iterations (see the call site in
  // NonMaxSuppressionOp):
  //   values[0] row_idx            - index of the box processed this step
  //   values[1] num_outputs_so_far - count of boxes accepted so far
  //   values[2] iou_mask           - [num_boxes, num_boxes] bool matrix of
  //                                  over-threshold IoU pairs (never mutated)
  //   values[3] included_iou       - [num_boxes] bool vector of boxes not
  //                                  (yet) suppressed
  xla::StatusOr<std::vector<xla::XlaOp>> operator()(
      absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
    auto row_idx = values[0];
    auto num_outputs_so_far = values[1];
    auto iou_mask = values[2];
    auto included_iou = values[3];
    auto zero = xla::ConstantR0<int32>(builder, 0);
    // Determine if current elem is active using a slice.
    // TODO(b/118437727): The only reason we need an explicit vector is because
    // some old GCCs can't deduce the right type for MakeConstSpan, and
    // providing a single-value initializer list directly uses the wrong
    // overload. Delete this once the deprecated overload is gone.
    std::vector<xla::XlaOp> row_idx_vector = {row_idx};
    auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
    active_elem = xla::Reshape(active_elem, {});
    // Increment output count iff current elem is not suppressed.
    num_outputs_so_far = xla::Select(
        active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
        num_outputs_so_far);
    // Slice out the row_idx.
    auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
    // Remove the diagonal from consideration. An elem cannot suppress
    // itself.
    row_iou = xla::DynamicUpdateSlice(
        row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
        {zero, row_idx});
    // Create a suppression by inverting polarity.
    row_iou = xla::Reshape(row_iou, {num_boxes});
    auto supp_mask = xla::Not(row_iou);
    // Update mask iff current elem is not suppressed.
    included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
                               xla::And(included_iou, supp_mask), included_iou);
    // Advance to the next box.
    row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
    return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
                                   included_iou};
  }
};

// XLA kernel for NonMaxSuppressionV4: greedily selects up to
// 'max_output_size' boxes in descending score order, suppressing boxes whose
// IoU with an already-selected box exceeds 'iou_threshold', and then
// discarding boxes whose score is not above 'score_threshold'.
// XLA needs static shapes, so pad_to_max_output_size must be true.
class NonMaxSuppressionOp : public XlaOpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compile(XlaOpKernelContext* context) override {
    // TODO(b/111646731): Improve scalability of this op, using blocking.
    const TensorShape& boxes_shape = context->InputShape("boxes");
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
                errors::InvalidArgument("boxes must be 2-D, currently: ",
                                        boxes_shape.DebugString()));
    const int64 num_boxes = boxes_shape.dim_size(0);
    OP_REQUIRES(context, boxes_shape.dim_size(1) == 4,
                errors::InvalidArgument("boxes must have 4 columns",
                                        boxes_shape.DebugString()));
    const TensorShape& scores_shape = context->InputShape("scores");
    OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
                errors::InvalidArgument("scores must be 1-D, currently: ",
                                        scores_shape.DebugString()));
    OP_REQUIRES(
        context, scores_shape.dim_size(0) == num_boxes,
        errors::InvalidArgument("scores size must equal number of boxes",
                                scores_shape.DebugString()));
    OP_REQUIRES(context, pad_to_max_output_size_,
                errors::InvalidArgument(
                    "XLA compilation requires pad_to_max_output_size == True"));
    // Box indices are carried as int32 below, so the count must fit.
    OP_REQUIRES(context, num_boxes <= kint32max,
                errors::InvalidArgument("XLA compilation requires number of "
                                        "boxes to be <= kint32max, got ",
                                        num_boxes));

    const xla::XlaOp boxes_input = context->Input("boxes");
    const xla::XlaOp scores_input = context->Input("scores");
    // max_output_size is a compile-time constant (see the registration
    // below); it fixes the static shape of output 0.
    int64 output_size;
    OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
    OP_REQUIRES(
        context, output_size >= 0,
        errors::InvalidArgument("Need output_size >= 0, got ", output_size));
    OP_REQUIRES(context, output_size <= kint32max,
                errors::InvalidArgument("Need output_size <= kint32Max, got ",
                                        output_size));
    const xla::XlaOp score_thresh = context->Input("score_threshold");
    const xla::XlaOp iou_thresh = context->Input("iou_threshold");
    xla::XlaBuilder* const builder = context->builder();

    // Choose a more convenient layout.
    const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
    // Sort the [4, num_boxes] box matrix along dim 1 by descending score;
    // negated scores, broadcast across the 4 coordinate rows, act as keys.
    const xla::XlaOp boxes_sorted = xla::GetTupleElement(
        xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}),
                  /*values=*/{boxes},
                  /*dimension=*/1),
        1);
    // Track the mapping of indices into sorted domain.
    const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
    const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices});
    const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
    // Negate the sorted keys back to recover the sorted scores themselves.
    const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0));

    // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0.
    const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/0,
                                                         /*limit_index=*/1,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/1,
                                                         /*limit_index=*/2,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/2,
                                                         /*limit_index=*/3,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/3,
                                                         /*limit_index=*/4,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});

    // Normalize corners so (y1, x1) <= (y2, x2) regardless of how the
    // caller ordered the box coordinates.
    xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
    xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
    xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
    xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
    xla::XlaOp area = (y2 - y1) * (x2 - x1);

    // Shapes are henceforth [1, num_boxes].
    y1 = xla::Broadcast(y1, {1});
    y2 = xla::Broadcast(y2, {1});
    x1 = xla::Broadcast(x1, {1});
    x2 = xla::Broadcast(x2, {1});
    area = xla::Broadcast(area, {1});

    // Pairwise intersection/union between every pair of boxes.
    // Shapes are henceforth [num_boxes, num_boxes].
    xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
    xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
    xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
    xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
    auto square_zero = xla::ZerosLike(i_xmin);

    // Clamp negative extents to zero so non-overlapping pairs get area 0.
    xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                        xla::Max(i_ymax - i_ymin, square_zero);
    xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
    xla::XlaOp iou = i_area / u_area;

    // Adding square_zero broadcasts the scalar threshold to the matrix shape.
    xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
    xla::XlaOp included_iou =
        xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});

    // Loop state consumed by WhileCondFn / SuppressBodyFn above.
    std::vector<xla::XlaOp> init_values;
    init_values.reserve(4);
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // col_idx
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
    init_values.push_back(iou_thresh_mask);
    init_values.push_back(included_iou);

    auto suppress_loop_result =
        xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
                             SuppressBodyFn(num_boxes), init_values,
                             "suppress_loop", builder)
            .ValueOrDie();

    // Keep a box only if it survived suppression AND its score exceeds the
    // score threshold.
    xla::XlaOp included_score =
        xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
    xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);

    // Only consider boxes over which we have iterated. This allows for accurate
    // counting. DynamicSlice would require knowledge of the size of the output.
    auto valid_elem = xla::Lt(
        iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
    included = xla::And(included, valid_elem);

    // Push excluded boxes to -inf so TopK ranks every kept box first.
    xla::XlaOp neg_inf =
        xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
    xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
    xla::XlaOp output_tuple = TopK(scores_included, output_size);
    xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
    // Calculate num_valid.
    // Note: num_valid cannot be taken from the loop outputs, because outputs
    // can be suppressed by score threshold.
    xla::XlaOp ones_included = xla::Select(
        included,
        xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
        xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
    // num_valid is scalar. Value should be bound by output_size.
    xla::XlaOp num_valid_total = xla::Reduce(
        ones_included,
        /*init_value=*/xla::ConstantR0<int>(builder, 0),
        /*computation=*/CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
    xla::XlaOp num_valid =
        xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));

    // Re-index into the original scores input tensor, using a Gather.
    // Boxes were suppressed in the sorted domain.
    xla::XlaOp selected_indices;
    DataType gather_type = context->expected_output_dtype(0);
    OP_REQUIRES_OK(
        context,
        XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
                  TensorShape({output_size}),
                  /*axis=*/0,
                  /*indices_are_nd=*/false,
                  /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));

    // Output 0: selected box indices (padded to output_size).
    // Output 1: number of valid entries in output 0.
    context->SetOutput(0, selected_indices);
    context->SetOutput(1, num_valid);
  }

 private:
  bool pad_to_max_output_size_;  // Attr; must be true under XLA (see above).
};

REGISTER_XLA_OP(
    Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
    NonMaxSuppressionOp);

}  // namespace
}  // namespace tensorflow