1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/tpu_defs.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21
22 namespace tensorflow {
23 namespace {
24
25 using ::tensorflow::errors::InvalidArgument;
26
27 // Computes the Kth order statistic of a data set. The current
28 // implementation uses a binary search requiring exactly 32 passes
29 // over the input data. The running time is linear with respect to
30 // input size. The median-of-medians algorithm is probably faster, but
31 // is difficult to implement efficiently in XLA. The implementation
32 // imposes a total ordering on floats. The ordering is consistent with
33 // the usual partial order. Positive NaNs are greater than positive
34 // infinity. Negative NaNs are less than negative infinity. NaNs with
35 // distinct payloads are treated as distinct. Subnormal numbers are
36 // preserved (not flushed to zero). Positive infinity is greater than
// all numbers. Negative infinity is less than all numbers. Positive
// zero is greater than negative zero. There are fewer than k values
// greater than the kth order statistic. There are at least k values greater
40 // than or equal to the Kth order statistic. The semantics are not the
41 // same as TopKUnique.
// Builds the XLA graph that computes, for each of the `height` rows of
// `input`, the k-th largest value under the total order described above.
// Returns an F32 vector of length `height`.
xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder,
                                              const TensorShape& input_shape,
                                              const xla::XlaOp input,
                                              const xla::XlaOp k) {
  const int64_t height = input_shape.dim_size(0);  // number of rows (batch)
  const int64_t width = input_shape.dim_size(1);   // elements per row

  // Reinterpret the float bits as S32 so the binary search can be done
  // with integer comparisons on the IEEE754 representation.
  xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32);
  xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0);
  xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height});
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  // 0x7FFFFFFF is the bit pattern of +QNAN with all payload bits set: the
  // maximum of the imposed total order.
  xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF);
  xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height});

  // Start at positive zero, so that pivot is always less than top.
  // 0x80000000 is the bit pattern of -0.0f.
  xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000);
  xla::XlaOp negative_zero_r1 = xla::Broadcast(negative_zero_r0, {height});
  xla::XlaOp top_r1 = zero_r1;

  // Binary search over the 32 bits of the result, most significant (sign)
  // bit first: exactly 32 passes over the input.
  for (uint32 mask = 1U << 31; mask; mask >>= 1) {
    xla::XlaOp broadcast_mask_r1 =
        xla::Broadcast(xla::ConstantR0<int32>(builder, mask), {height});

    // The first iteration of the loop determines whether the kth element
    // is negative or non-negative: if it is non-negative the search
    // descends from +QNAN (0x7FFFFFFF), otherwise it descends from -0
    // (0x80000000). On every iteration the pivot is less than the top and
    // is always half way between the top and the implicit bottom in
    // IEEE754 space.
    xla::XlaOp pivot_r1 = xla::Xor(top_r1, broadcast_mask_r1);
    xla::XlaOp pivot_r2 = xla::Add(pivot_r1, zero_r2, {0});
    // In the S32 view, negative floats sort in reverse order. When both
    // operands have the sign bit set (detected via AND < 0), swap them so
    // that the Gt below agrees with the float total order.
    xla::XlaOp both_negative_r2 =
        xla::Lt(xla::And(input_sm32, pivot_r2), zero_r0);
    xla::XlaOp left_r2 = xla::Select(both_negative_r2, pivot_r2, input_sm32);
    xla::XlaOp right_r2 = xla::Select(both_negative_r2, input_sm32, pivot_r2);
    xla::XlaOp pred_r2 = xla::Gt(left_r2, right_r2);
    xla::XlaOp conv_r2 = xla::ConvertElementType(pred_r2, xla::S32);

    // sum_r1[row] = number of elements in the row strictly greater than
    // the pivot.
    xla::XlaComputation add = CreateScalarAddComputation(xla::S32, builder);
    xla::XlaOp sum_r1 = xla::Reduce(conv_r2, zero_r0, add, {1});

    // If at least k elements exceed the pivot, the pivot is too low; keep
    // the current top. Otherwise lower the top to the pivot.
    xla::XlaOp pivot_too_low_r1 = xla::Le(k, sum_r1, {});

    if (mask == (1U << 31)) {
      top_r1 = xla::Select(pivot_too_low_r1, max_r1, negative_zero_r1);
    } else {
      top_r1 = xla::Select(pivot_too_low_r1, top_r1, pivot_r1);
    }
  }
  return xla::BitcastConvertType(top_r1, xla::F32);
}
94
95 class KthOrderStatistic : public XlaOpKernel {
96 public:
KthOrderStatistic(OpKernelConstruction * ctx)97 explicit KthOrderStatistic(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
98 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
99 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
100 }
101
Compile(XlaOpKernelContext * ctx)102 void Compile(XlaOpKernelContext* ctx) override {
103 xla::XlaBuilder* builder = ctx->builder();
104 xla::XlaOp input = ctx->Input(0);
105 const TensorShape& input_shape = ctx->InputShape(0);
106 OP_REQUIRES(
107 ctx, input_shape.dims() == 2,
108 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
109
110 xla::XlaOp k = xla::ConstantR0<int32>(builder, k_);
111 xla::XlaOp kth_order_statistics =
112 CreateKthOrderStatisticComputation(builder, input_shape, input, k);
113 ctx->SetOutput(0, kth_order_statistics);
114 }
115
116 private:
117 int32 k_;
118 };
119
120 REGISTER_XLA_OP(Name("KthOrderStatistic"), KthOrderStatistic);
121
122 // Returns the TopK unique values in the array in sorted order and the
123 // indices of those elements. The running time is proportional to the
124 // product of K and the input size. Sorting the whole array is more
125 // efficient for sufficiently large values of K. The median-of-medians
126 // algorithm is probably faster, but difficult to implement
127 // efficiently in XLA. If there are fewer than K unique values, the
128 // results are padded with negative infinity. NaNs are never
129 // returned. Subnormal numbers are flushed to zero.
130 //
131 // If an element appears at multiple indices, the highest index is
132 // returned. If a TopK element never appears in the input due to
133 // padding values, the indices are padded with negative one. If a
134 // padding value appears in the input and padding is needed, the
135 // highest index of the padding value will be returned.
136 //
137 // The semantics are not the same as KthOrderStatistic.
138 //
139 // If masked_with_iota is true, the index is already encoded in the lower bits
140 // of the mantissa, which will be extracted as the index in the output.
141 // Otherwise, every iteration will use the following algorithm to get the index:
142 // index = max([i if data[i] == max else -1 for i in size])
143 //
144 // TODO(b/74994968): Replace TopKUnique with an LLO implementation of
145 // TopK with reasonable semantics.
// Builds the XLA graph returning, per row, the top k unique values (first
// element of the pair) and their indices (second element), each shaped
// {height, k}. Performs k max-reductions; after each one, every element
// greater than or equal to the current maximum is masked to -inf so the
// next reduction finds the next distinct value.
std::pair<xla::XlaOp, xla::XlaOp> CreateTopKUnique(
    xla::XlaBuilder* builder, const xla::XlaOp input,
    const TensorShape& input_shape, int64_t k, bool masked_with_iota) {
  const int64_t height = input_shape.dim_size(0);  // number of rows (batch)
  const int64_t width = input_shape.dim_size(1);   // elements per row

  // Column indices, used to recover positions when indices are not
  // already encoded in the mantissa bits.
  xla::XlaOp iota_r1 = xla::Iota(builder, xla::S32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  // -1 marks "this position does not hold the current maximum".
  xla::XlaOp negative_one_r0 = xla::ConstantR0<int>(builder, -1);
  xla::XlaOp negative_one_r2 = xla::Broadcast(negative_one_r0, {height, width});

  // -inf is both the reduction identity for max and the mask value for
  // already-extracted elements.
  xla::XlaOp negative_infinity_r0 = xla::ConstantR0<float>(builder, -INFINITY);
  xla::XlaOp negative_infinity_r2 =
      xla::Broadcast(negative_infinity_r0, {height, width});

  xla::XlaOp scratch_pad_r2 = input;
  std::vector<xla::XlaOp> topk_r1s;
  std::vector<xla::XlaOp> topk_indices;
  for (int i = 0; i < k; ++i) {
    // Row-wise maximum of the still-unmasked elements: the i-th largest
    // unique value.
    xla::XlaOp kth_order_statistic_r1 =
        xla::Reduce(scratch_pad_r2, negative_infinity_r0,
                    CreateScalarMaxComputation(xla::F32, builder), {1});
    topk_r1s.push_back(kth_order_statistic_r1);

    // Rebuild the scratch pad from the original input, masking everything
    // >= the current maximum. This also re-masks all previously extracted
    // values, since they are larger.
    xla::XlaOp ge_r2 = xla::Ge(input, kth_order_statistic_r1, {0});
    scratch_pad_r2 = xla::Select(ge_r2, negative_infinity_r2, input);

    if (!masked_with_iota) {
      // Recover the index: the highest position whose value equals the
      // current maximum, or -1 if it never appears.
      xla::XlaOp eq_r2 = xla::Eq(input, kth_order_statistic_r1, {0});
      xla::XlaOp indices_r2 = xla::Select(eq_r2, iota_r2, negative_one_r2);
      xla::XlaOp topk_index_r1 =
          xla::Reduce(indices_r2, negative_one_r0,
                      CreateScalarMaxComputation(xla::S32, builder), {1});
      topk_indices.push_back(topk_index_r1);
    }
  }
  // Stack the k per-row vectors into a {height, k} array.
  xla::XlaOp topk_r1_concat = xla::ConcatInDim(builder, topk_r1s, 0);
  xla::XlaOp topk_r2 =
      xla::Transpose(xla::Reshape(topk_r1_concat, {k, height}), {1, 0});

  xla::XlaOp topk_indices_r2;
  if (masked_with_iota) {
    // The index was pre-encoded (by MakeUnique) in the low mantissa bits;
    // extract it with a mask wide enough to cover all column indices.
    int32_t log2_ceiling = tensorflow::Log2Ceiling(width);
    int32_t next_power_of_two = 1U << log2_ceiling;
    int32_t count_mask = next_power_of_two - 1;
    xla::XlaOp mask_r0 = xla::ConstantR0(builder, count_mask);
    xla::XlaOp mask_r2 = xla::Broadcast(mask_r0, {height, k});
    xla::XlaOp topk_r2_s32 = xla::BitcastConvertType(topk_r2, xla::S32);
    topk_indices_r2 = xla::And(topk_r2_s32, mask_r2);
  } else {
    xla::XlaOp topk_indices_concat = xla::ConcatInDim(builder, topk_indices, 0);
    topk_indices_r2 =
        xla::Transpose(xla::Reshape(topk_indices_concat, {k, height}), {1, 0});
  }
  return std::make_pair(topk_r2, topk_indices_r2);
}
203
204 class TopKUnique : public XlaOpKernel {
205 public:
TopKUnique(OpKernelConstruction * ctx)206 explicit TopKUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
207 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
208 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
209 }
210
Compile(XlaOpKernelContext * ctx)211 void Compile(XlaOpKernelContext* ctx) override {
212 xla::XlaBuilder* builder = ctx->builder();
213 xla::XlaOp input = ctx->Input(0);
214 const TensorShape& input_shape = ctx->InputShape(0);
215 OP_REQUIRES(
216 ctx, input_shape.dims() == 2,
217 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
218
219 auto topk = CreateTopKUnique(builder, input, input_shape, k_, false);
220 ctx->SetOutput(0, topk.first);
221 ctx->SetOutput(1, topk.second);
222 }
223
224 private:
225 int k_;
226 };
227 REGISTER_XLA_OP(Name("TopKUnique"), TopKUnique);
228
229 // Make all elements in the non-Batch dimension unique and close to
// their initial value on a relative scale, but potentially far from
231 // their initial value in an absolute scale.
232 //
233 // This operation is meant to be combined with TopKUnique to avoid
234 // suppressing identical elements. For most TopK users, the indices of
235 // the TopK elements are important but the relative order of the TopK
// elements and their exact values is not so important. Ideally,
// the indices of the TopK elements of the output of MakeUnique are
238 // the same as the indices of the TopK elements of the inputs.
239 //
// It's an open question whether it is better to accept the risk of two
241 // elements in the input to TopK have exactly the same value or the
242 // risk that MakeUnique will alter the indices of the TopK
243 // elements. Model owners are encouraged to experiment!
244 //
245 // Never returns a sub-normal number. Never returns zero. The sign of
246 // each input element is always identical to the sign of the
247 // corresponding output element. Behavior for infinite elements is
248 // undefined. Behavior for subnormal elements is undefined.
249 //
250 // Algorithm:
251 // 1. Replace zeros with the smallest representable normal floating
252 // point number with the same sign.
253 // 2. Mask away enough low order bits that every value can be distinct.
254 // 3. Replace the low order bits with iota.
255 //
256 // TODO(b/74994968): Replace MakeUnique with an LLO implementation of
257 // TopK with reasonable semantics.
CreateMakeUnique(xla::XlaBuilder * builder,const xla::XlaOp input,const TensorShape & input_shape)258 xla::XlaOp CreateMakeUnique(xla::XlaBuilder* builder, const xla::XlaOp input,
259 const TensorShape& input_shape) {
260 const int64_t height = input_shape.dim_size(0);
261 const int64_t width = input_shape.dim_size(1);
262
263 xla::XlaOp zero_r0 = xla::ConstantR0(builder, 0U);
264 xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});
265
266 // count_mask is used to mask away the low order bits to ensure
267 // that every element is distinct.
268 uint32 log2_ceiling = static_cast<uint32>(std::ceil(std::log2(width)));
269 uint32 next_power_of_two = 1U << log2_ceiling;
270 uint32 count_mask = ~(next_power_of_two - 1);
271 xla::XlaOp count_mask_r0 = xla::ConstantR0(builder, count_mask);
272 xla::XlaOp count_mask_r2 = xla::Broadcast(count_mask_r0, {height, width});
273
274 // smallest_normal is the bit representation of the smallest
275 // positive normal floating point number. The sign is zero,
276 // exponent is one, and the fraction is zero.
277 uint32 smallest_normal = 1U << 23;
278 xla::XlaOp smallest_normal_r0 = xla::ConstantR0(builder, smallest_normal);
279 xla::XlaOp smallest_normal_r2 =
280 xla::Broadcast(smallest_normal_r0, {height, width});
281
282 // Used to mask away the sign bit when computing the absolute
283 // value.
284 uint32 low_bit_mask = ~(1U << 31);
285 xla::XlaOp low_bit_mask_r0 = xla::ConstantR0(builder, low_bit_mask);
286 xla::XlaOp low_bit_mask_r2 = xla::Broadcast(low_bit_mask_r0, {height, width});
287
288 xla::XlaOp iota_r1 = xla::Iota(builder, xla::U32, width);
289 xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});
290
291 // Compare the absolute value with positive zero to handle
292 // negative zero.
293 //
294 // Pseudocode: input_no_zeros = abs(input) == 0 ? FLT_MIN : input
295 xla::XlaOp input_u32_r2 = xla::BitcastConvertType(input, xla::U32);
296 xla::XlaOp abs_r2 = xla::And(input_u32_r2, low_bit_mask_r2);
297 xla::XlaOp if_zero_r2 = xla::Eq(abs_r2, zero_r2);
298 xla::XlaOp smallest_normal_preserving_sign_r2 =
299 xla::Or(input_u32_r2, smallest_normal_r2);
300 xla::XlaOp input_no_zeros_r2 =
301 xla::Select(if_zero_r2, smallest_normal_preserving_sign_r2, input_u32_r2);
302
303 // Discard the low-order bits and replace with iota.
304 xla::XlaOp and_r2 = xla::And(input_no_zeros_r2, count_mask_r2);
305 xla::XlaOp or_r2 = xla::Or(and_r2, iota_r2);
306 return xla::BitcastConvertType(or_r2, xla::F32);
307 }
308
309 class MakeUnique : public XlaOpKernel {
310 public:
MakeUnique(OpKernelConstruction * ctx)311 explicit MakeUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
312
Compile(XlaOpKernelContext * ctx)313 void Compile(XlaOpKernelContext* ctx) override {
314 xla::XlaBuilder* builder = ctx->builder();
315 xla::XlaOp input = ctx->Input(0);
316 const TensorShape& input_shape = ctx->InputShape(0);
317 OP_REQUIRES(
318 ctx, input_shape.dims() == 2,
319 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
320
321 ctx->SetOutput(0, CreateMakeUnique(builder, input, input_shape));
322 }
323 };
324 REGISTER_XLA_OP(Name("MakeUnique"), MakeUnique);
325
326 // Returns the TopK approximate values in the array in sorted order and the
327 // indices of those elements. The running time is proportional to the
328 // product of K and the input size.
329 //
330 // The algorithm first updates the lower bits of each element with iota,
331 // which is used to derive the index. The iota also serves the purpose to
332 // make each element unique so that each iteration, we are guaranteed to
333 // get one and only one unique top-1 element.
334 class TopKWithUnique : public XlaOpKernel {
335 public:
TopKWithUnique(OpKernelConstruction * ctx)336 explicit TopKWithUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
337 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
338 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
339 }
340
Compile(XlaOpKernelContext * ctx)341 void Compile(XlaOpKernelContext* ctx) override {
342 xla::XlaBuilder* builder = ctx->builder();
343 xla::XlaOp input = ctx->Input(0);
344 const TensorShape& input_shape = ctx->InputShape(0);
345 OP_REQUIRES(
346 ctx, input_shape.dims() == 2,
347 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
348
349 xla::XlaOp unique = CreateMakeUnique(builder, input, input_shape);
350 auto topk = CreateTopKUnique(builder, unique, input_shape, k_, true);
351 ctx->SetOutput(0, topk.first);
352 ctx->SetOutput(1, topk.second);
353 }
354
355 private:
356 int k_;
357 };
358 REGISTER_XLA_OP(Name("TopKWithUnique"), TopKWithUnique);
359 } // namespace
360 } // namespace tensorflow
361