/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <numeric>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

// Converts the red/green/blue channels in 'rgb' from RGB format to HSV format.
// 'shape' is the shape of each of the red/green/blue tensors.
std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
                                   const std::array<xla::XlaOp, 3>& rgb,
                                   DataType dtype, const TensorShape& shape) {
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::One(b, dtype);

  auto red = rgb[0];
  auto green = rgb[1];
  auto blue = rgb[2];
  auto value = xla::Max(xla::Max(red, green), blue);
  auto minimum = xla::Min(xla::Min(red, green), blue);
  auto range = xla::Sub(value, minimum);

  auto zeros = xla::Broadcast(zero, shape.dim_sizes());
  auto saturation =
      xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);

  // 'norm' may be infinite where range == 0; those lanes are replaced by
  // zeros in the Gt(range, zero) select below.
  auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);

  // Piecewise hue: the branch is chosen by which channel equals the max.
  auto hue =
      xla::Select(xla::Eq(green, value),
                  xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
                           XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
                  xla::Add(xla::Mul(norm, xla::Sub(red, green)),
                           XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
  hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
                    hue);
  hue = xla::Select(xla::Gt(range, zero), hue, zeros);
  // Wrap negative hues around into [0, 1).
  hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
  return {hue, saturation, value};
}
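// Illustrative check of the conversion above (not part of the kernel): for a
// pure-red pixel rgb = {1, 0, 0}, value = 1, range = 1, saturation = 1, and
// the red-is-max branch gives hue = (1/6) * (0 - 0) = 0, so hsv = {0, 1, 1}.
// For rgb = {0, 1, 0}, the green-is-max branch gives hue = 2/6.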

// Converts 'hsv' from HSV format to RGB format.
std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
                                   const std::array<xla::XlaOp, 3>& hsv,
                                   DataType dtype) {
  xla::XlaOp hue = hsv[0];
  xla::XlaOp saturation = hsv[1];
  xla::XlaOp value = hsv[2];
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
  auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
  auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
  auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);

  // Scale hue to [0, 6) and build one triangular ramp per channel.
  auto dh = xla::Mul(hue, six);
  auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
  auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
  auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
  auto one_minus_s = xla::Sub(one, saturation);

  // Blend each ramp with white by saturation, then scale by value.
  auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
  auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
  auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
  return {red, green, blue};
}
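// Illustrative check (not part of the kernel): hsv = {2.0/6, 1, 1} gives
// dh = 2, so dr = clamp(0, |2-3| - 1, 1) = 0, dg = clamp(0, 2 - |2-2|, 1) = 1,
// db = clamp(0, 2 - |2-4|, 1) = 0, recovering pure green {0, 1, 0}.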

class RGBToHSVOp : public XlaOpKernel {
 public:
  explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D: ",
                                        input_shape.DebugString()));
    int channel_dim = input_shape.dims() - 1;
    int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but has ",
                                   channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input = context->Input(0);

    // Slice the R, G, B channels off the last dimension.
    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
                        channel_shape);

    context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
  }
};
REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);

class HSVToRGBOp : public XlaOpKernel {
 public:
  explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D: ",
                                        input_shape.DebugString()));
    int channel_dim = input_shape.dims() - 1;
    int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but has ",
                                   channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input = context->Input(0);
    xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
                                            /*limit_index=*/2, /*stride=*/1,
                                            /*dimno=*/channel_dim);
    xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
                                       /*limit_index=*/3, /*stride=*/1,
                                       /*dimno=*/channel_dim);

    auto rgb = HSVToRGB(b, {hue, saturation, value}, context->input_type(0));

    context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
  }
};
REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);

class AdjustContrastOpV2 : public XlaOpKernel {
 public:
  explicit AdjustContrastOpV2(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& factor_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    int height_dim = input_shape.dims() - 3;
    int width_dim = input_shape.dims() - 2;
    int channel_dim = input_shape.dims() - 1;
    const int64 height = input_shape.dim_size(height_dim);
    const int64 width = input_shape.dim_size(width_dim);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor_shape.DebugString()));

    xla::XlaBuilder* b = context->builder();
    DataType type = context->input_type(0);

    xla::XlaOp input = context->Input(0);
    xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);

    // Compute the per-channel mean over the spatial dimensions, accumulating
    // in a wider type to limit rounding error.
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                              *context->GetOrCreateAdd(accumulation_type),
                              {height_dim, width_dim});

    auto output = xla::Div(
        reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
    output = XlaHelpers::ConvertElementType(output, type);

    // output = factor * input + (1 - factor) * mean, i.e. interpolate (or
    // extrapolate) between the per-channel mean and the input.
    std::vector<int64> broadcast_dims(input_shape.dims() - 2);
    std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    broadcast_dims.back() = channel_dim;
    output =
        xla::Add(xla::Mul(input, factor),
                 xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
                 broadcast_dims);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
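
// For reference, the kernel above computes, per channel,
//   output = factor * input + (1 - factor) * mean(input over H, W),
// so with a channel mean of 0.5 and factor = 2 a pixel value of 0.6 maps to
// 2 * 0.6 + (1 - 2) * 0.5 = 0.7 (illustrative arithmetic only).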

class AdjustSaturationOp : public XlaOpKernel {
 public:
  explicit AdjustSaturationOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& scale_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
                errors::InvalidArgument("scale must be scalar: ",
                                        scale_shape.DebugString()));
    const int channel_dim = input_shape.dims() - 1;
    const int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    // Do the conversion in float, independent of the input type.
    xla::XlaOp input =
        XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
    xla::XlaOp scale =
        XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);

    DataType type = context->input_type(0);

    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv =
        RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);

    // Scale the saturation channel and clamp the result back into [0, 1].
    hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
                        XlaHelpers::One(b, DT_FLOAT));

    auto rgb = HSVToRGB(b, hsv, DT_FLOAT);

    auto output = XlaHelpers::ConvertElementType(
        xla::ConcatInDim(b, rgb, channel_dim), type);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
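
// Note on the kernel above (illustrative only): scaling saturation by 0
// collapses every pixel to gray, since with s = 0 HSVToRGB yields
// r = g = b = v; scales that push s above 1 are clamped back to 1.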

class AdjustHueOp : public XlaOpKernel {
 public:
  explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& delta_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
                errors::InvalidArgument("delta must be scalar: ",
                                        delta_shape.DebugString()));
    const int channel_dim = input_shape.dims() - 1;
    const int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    // Do the conversion in float, independent of the input type.
    xla::XlaOp input =
        XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
    xla::XlaOp delta =
        XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);

    DataType type = context->input_type(0);

    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv =
        RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);

    auto zero = XlaHelpers::Zero(b, DT_FLOAT);
    auto one = XlaHelpers::One(b, DT_FLOAT);

    // hue = (hue + delta) mod 1.0, with negative remainders wrapped back
    // into [0, 1).
    auto& hue = hsv[0];
    hue = xla::Rem(xla::Add(hsv[0], delta), one);
    hue =
        xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);

    auto rgb = HSVToRGB(b, hsv, DT_FLOAT);

    auto output = XlaHelpers::ConvertElementType(
        xla::ConcatInDim(b, rgb, channel_dim), type);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
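
// Example for the kernel above (illustrative only): delta = 0.5 rotates hue
// by half a turn, so pure red (hue 0) becomes cyan (hue 0.5); the Rem/Select
// pair wraps any result back into [0, 1).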

struct WhileCondFn {
  const int64 num_boxes;
  const int64 output_size;

  explicit WhileCondFn(int64 num_boxes, int64 output_size)
      : num_boxes(num_boxes), output_size(output_size) {}

  // Continue while the current row is in bounds and the output is not yet
  // full.
  xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
                                       xla::XlaBuilder* cond_builder) const {
    xla::XlaOp row_idx = values[0];
    xla::XlaOp row_in_bounds =
        xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
    xla::XlaOp num_outputs_so_far = values[1];
    xla::XlaOp results_not_full = xla::Lt(
        num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
    return xla::And(row_in_bounds, results_not_full);
  }
};

// Processes the boxes one at a time using the IoU matrix mask.
// This implementation uses a correct, but greedy, sequential algorithm to
// ensure that suppressed boxes cannot themselves suppress other boxes.
struct SuppressBodyFn {
  const int64 num_boxes;

  explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {}

  xla::StatusOr<std::vector<xla::XlaOp>> operator()(
      absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
    auto row_idx = values[0];
    auto num_outputs_so_far = values[1];
    auto iou_mask = values[2];
    auto included_iou = values[3];
    auto zero = xla::ConstantR0<int32>(builder, 0);
    // Determine if the current elem is active using a slice.
    // TODO(b/118437727): The only reason we need an explicit vector is because
    // some old GCCs can't deduce the right type for MakeConstSpan, and
    // providing a single-value initializer list directly uses the wrong
    // overload. Delete this once the deprecated overload is gone.
    std::vector<xla::XlaOp> row_idx_vector = {row_idx};
    auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
    active_elem = xla::Reshape(active_elem, {});
    // Increment the output count iff the current elem is not suppressed.
    num_outputs_so_far = xla::Select(
        active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
        num_outputs_so_far);
    // Slice out the current row of the IoU mask.
    auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
    // Remove the diagonal from consideration: an elem cannot suppress itself.
    row_iou = xla::DynamicUpdateSlice(
        row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
        {zero, row_idx});
    // Invert the polarity to get the suppression mask: after xla::Not, true
    // means a box survives the current row.
    row_iou = xla::Reshape(row_iou, {num_boxes});
    auto supp_mask = xla::Not(row_iou);
    // Update the mask iff the current elem is not suppressed.
    included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
                               xla::And(included_iou, supp_mask), included_iou);
    row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
    return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
                                   included_iou};
  }
};
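
// Worked example for the body above (3 boxes, sorted by descending score):
// if box 0 overlaps box 1 above the IoU threshold but not box 2, the first
// iteration clears included_iou[1]; when the loop reaches row 1, active_elem
// is false, so box 1 neither counts toward the output nor suppresses box 2.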

class NonMaxSuppressionOp : public XlaOpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compile(XlaOpKernelContext* context) override {
    // TODO(b/111646731): Improve scalability of this op, using blocking.
    const TensorShape& boxes_shape = context->InputShape("boxes");
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
                errors::InvalidArgument("boxes must be 2-D, currently: ",
                                        boxes_shape.DebugString()));
    const int64 num_boxes = boxes_shape.dim_size(0);
    OP_REQUIRES(context, boxes_shape.dim_size(1) == 4,
                errors::InvalidArgument("boxes must have 4 columns, got ",
                                        boxes_shape.DebugString()));
    const TensorShape& scores_shape = context->InputShape("scores");
    OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
                errors::InvalidArgument("scores must be 1-D, currently: ",
                                        scores_shape.DebugString()));
    OP_REQUIRES(
        context, scores_shape.dim_size(0) == num_boxes,
        errors::InvalidArgument("scores size must equal number of boxes: ",
                                scores_shape.DebugString()));
    OP_REQUIRES(context, pad_to_max_output_size_,
                errors::InvalidArgument(
                    "XLA compilation requires pad_to_max_output_size == True"));
    OP_REQUIRES(context, num_boxes <= kint32max,
                errors::InvalidArgument("XLA compilation requires number of "
                                        "boxes to be <= kint32max, got ",
                                        num_boxes));

    const xla::XlaOp boxes_input = context->Input("boxes");
    const xla::XlaOp scores_input = context->Input("scores");
    int64 output_size;
    OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
    OP_REQUIRES(
        context, output_size >= 0,
        errors::InvalidArgument("Need output_size >= 0, got ", output_size));
    OP_REQUIRES(context, output_size <= kint32max,
                errors::InvalidArgument("Need output_size <= kint32max, got ",
                                        output_size));
    const xla::XlaOp score_thresh = context->Input("score_threshold");
    const xla::XlaOp iou_thresh = context->Input("iou_threshold");
    xla::XlaBuilder* const builder = context->builder();

    // Choose a more convenient layout.
    const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
    // Sort the boxes by descending score: negate the scores, broadcast them
    // across the four coordinate rows, and sort along dimension 1.
    const xla::XlaOp boxes_sorted = xla::GetTupleElement(
        xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}),
                  /*values=*/{boxes},
                  /*dimension=*/1),
        1);
    // Track the mapping of indices into the sorted domain.
    const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
    const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices});
    const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
    const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0));

    // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0.
    const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/0,
                                                         /*limit_index=*/1,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/1,
                                                         /*limit_index=*/2,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/2,
                                                         /*limit_index=*/3,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/3,
                                                         /*limit_index=*/4,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});

    // Canonicalize the corners so that (y1, x1) <= (y2, x2) regardless of
    // the corner order in the input.
    xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
    xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
    xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
    xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
    xla::XlaOp area = (y2 - y1) * (x2 - x1);

    // Shapes are henceforth [1, num_boxes].
    y1 = xla::Broadcast(y1, {1});
    y2 = xla::Broadcast(y2, {1});
    x1 = xla::Broadcast(x1, {1});
    x2 = xla::Broadcast(x2, {1});
    area = xla::Broadcast(area, {1});

    // Shapes are henceforth [num_boxes, num_boxes]: entry (i, j) holds the
    // intersection bounds of boxes i and j.
    xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
    xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
    xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
    xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
    auto square_zero = xla::ZerosLike(i_xmin);

    // iou = intersection area / union area, with empty intersections clamped
    // to zero area.
    xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                        xla::Max(i_ymax - i_ymin, square_zero);
    xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
    xla::XlaOp iou = i_area / u_area;
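    // Illustrative check: two unit squares overlapping in a 0.5 x 1 strip
    // give i_area = 0.5, u_area = 1.5, and iou = 1/3. Each diagonal entry is
    // 1 and is masked out inside the loop body.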

    xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
    xla::XlaOp included_iou =
        xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});

    std::vector<xla::XlaOp> init_values;
    init_values.reserve(4);
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // row_idx
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
    init_values.push_back(iou_thresh_mask);
    init_values.push_back(included_iou);

    auto suppress_loop_result =
        xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
                             SuppressBodyFn(num_boxes), init_values,
                             "suppress_loop", builder)
            .ValueOrDie();

    xla::XlaOp included_score =
        xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
    xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);

    // Only consider boxes over which the loop has iterated; this allows for
    // accurate counting. DynamicSlice would require knowledge of the size of
    // the output.
    auto valid_elem = xla::Lt(
        iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
    included = xla::And(included, valid_elem);

    // Push excluded scores to the end by replacing them with -inf, then TopK.
    xla::XlaOp neg_inf =
        xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
    xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
    xla::XlaOp output_tuple = xla::TopK(scores_included, output_size);
    xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
    // Calculate num_valid.
    // Note: num_valid cannot be taken from the loop outputs, because outputs
    // can also be suppressed by the score threshold.
    xla::XlaOp ones_included = xla::Select(
        included,
        xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
        xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
    // num_valid is scalar. Its value is bounded by output_size.
    xla::XlaOp num_valid_total = xla::Reduce(
        ones_included,
        /*init_value=*/xla::ConstantR0<int32>(builder, 0),
        /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
    xla::XlaOp num_valid =
        xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));

    // Re-index into the original scores input tensor, using a Gather; boxes
    // were suppressed in the sorted domain.
    xla::XlaOp selected_indices;
    DataType gather_type = context->expected_output_dtype(0);
    OP_REQUIRES_OK(
        context,
        XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
                  TensorShape({output_size}),
                  /*axis=*/0,
                  /*indices_are_nd=*/false,
                  /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));

    context->SetOutput(0, selected_indices);
    context->SetOutput(1, num_valid);
  }

 private:
  bool pad_to_max_output_size_;
};

REGISTER_XLA_OP(
    Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
    NonMaxSuppressionOp);
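
// Usage note: NonMaxSuppressionV4 is the padded variant reached from Python
// via tf.image.non_max_suppression_padded; as enforced above, XLA
// additionally requires pad_to_max_output_size=True so the output shape is
// static.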

}  // namespace
}  // namespace tensorflow