/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

using ::tensorflow::errors::InvalidArgument;

// Computes the kth order statistic of a data set. The current
// implementation uses a binary search requiring exactly 32 passes
// over the input data. The running time is linear with respect to
// the input size. The median-of-medians algorithm is probably faster,
// but is difficult to implement efficiently in XLA. The
// implementation imposes a total ordering on floats. The ordering is
// consistent with the usual partial order. Positive NaNs are greater
// than positive infinity. Negative NaNs are less than negative
// infinity. NaNs with distinct payloads are treated as distinct.
// Subnormal numbers are preserved (not flushed to zero). Positive
// infinity is greater than all numbers. Negative infinity is less
// than all numbers. Positive zero is greater than negative zero.
// There are fewer than k values greater than the kth order statistic.
// There are at least k values greater than or equal to the kth order
// statistic. The semantics are not the same as TopKUnique.
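//
// For example (an illustration, not taken from the op's tests): for
// the row {3.0, 1.0, 2.0, 2.0} and k = 2 the result is 2.0, since
// exactly one value (3.0) is greater than it and three values are
// greater than or equal to it.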
xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder,
                                              const TensorShape& input_shape,
                                              const xla::XlaOp input,
                                              const xla::XlaOp k) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32);
  xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0);
  xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height});
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF);
  xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height});

  // Start at positive zero, so that the pivot is always less than top.
  xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000);
  xla::XlaOp negative_zero_r1 = xla::Broadcast(negative_zero_r0, {height});
  xla::XlaOp top_r1 = zero_r1;

  for (uint32 mask = 1U << 31; mask; mask >>= 1) {
    xla::XlaOp broadcast_mask_r1 =
        xla::Broadcast(xla::ConstantR0<int32>(builder, mask), {height});

    // The first iteration of the loop determines whether the kth
    // element is positive or negative. If the kth element is positive,
    // we start the search from +QNAN (0x7FFFFFFF); if it is negative,
    // we start from -0 (0x80000000). The pivot is less than the top
    // and is always halfway between the top and the implicit bottom in
    // IEEE754 space.
    xla::XlaOp pivot_r1 = xla::Xor(top_r1, broadcast_mask_r1);
    xla::XlaOp pivot_r2 = xla::Add(pivot_r1, zero_r2, {0});
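    // For negative IEEE754 floats, a larger bit pattern means a
    // smaller value, so when both operands are negative they are
    // swapped before comparing. The Gt below therefore counts, per
    // row, the elements strictly greater than the pivot under the
    // total order described above.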
    xla::XlaOp both_negative_r2 =
        xla::Lt(xla::And(input_sm32, pivot_r2), zero_r0);
    xla::XlaOp left_r2 = xla::Select(both_negative_r2, pivot_r2, input_sm32);
    xla::XlaOp right_r2 = xla::Select(both_negative_r2, input_sm32, pivot_r2);
    xla::XlaOp pred_r2 = xla::Gt(left_r2, right_r2);
    xla::XlaOp conv_r2 = xla::ConvertElementType(pred_r2, xla::S32);

    xla::XlaComputation add = CreateScalarAddComputation(xla::S32, builder);
    xla::XlaOp sum_r1 = xla::Reduce(conv_r2, zero_r0, add, {1});

    xla::XlaOp pivot_too_low_r1 = xla::Le(k, sum_r1, {});

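    // Binary-search step: if at least k elements are strictly greater
    // than the pivot, the kth order statistic lies above the pivot,
    // so top keeps its current bit; otherwise top descends to the
    // pivot.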
    if (mask == (1U << 31)) {
      top_r1 = xla::Select(pivot_too_low_r1, max_r1, negative_zero_r1);
    } else {
      top_r1 = xla::Select(pivot_too_low_r1, top_r1, pivot_r1);
    }
  }
  return xla::BitcastConvertType(top_r1, xla::F32);
}

class KthOrderStatistic : public XlaOpKernel {
 public:
  explicit KthOrderStatistic(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    xla::XlaOp k = xla::ConstantR0<int32>(builder, k_);
    xla::XlaOp kth_order_statistics =
        CreateKthOrderStatisticComputation(builder, input_shape, input, k);
    ctx->SetOutput(0, kth_order_statistics);
  }

 private:
  int32 k_;
};

REGISTER_XLA_OP(Name("KthOrderStatistic"), KthOrderStatistic);

// Returns the TopK unique values in the array in sorted order and the
// indices of those elements. The running time is proportional to the
// product of K and the input size. Sorting the whole array is more
// efficient for sufficiently large values of K. The median-of-medians
// algorithm is probably faster, but is difficult to implement
// efficiently in XLA. If there are fewer than K unique values, the
// results are padded with negative infinity. NaNs are never
// returned. Subnormal numbers are flushed to zero.
//
// If an element appears at multiple indices, the highest index is
// returned. If a TopK element never appears in the input due to
// padding values, the indices are padded with negative one. If a
// padding value appears in the input and padding is needed, the
// highest index of the padding value will be returned.
//
// The semantics are not the same as KthOrderStatistic.
//
// If masked_with_iota is true, the index is already encoded in the lower bits
// of the mantissa, which will be extracted as the index in the output.
// Otherwise, every iteration will use the following algorithm to get the index:
//   index = max([i if data[i] == max else -1 for i in range(size)])
//
// TODO(b/74994968): Replace TopKUnique with an LLO implementation of
// TopK with reasonable semantics.
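//
// For example (an illustration, not taken from the op's tests): for
// the row {1.0, 2.0, 2.0} with K = 2, the values are {2.0, 1.0} and
// the indices are {2, 0}; the duplicated 2.0 reports its highest
// index, and everything tied with the current maximum is suppressed
// before the next iteration.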
std::pair<xla::XlaOp, xla::XlaOp> CreateTopKUnique(
    xla::XlaBuilder* builder, const xla::XlaOp input,
    const TensorShape& input_shape, int64_t k, bool masked_with_iota) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp iota_r1 = xla::Iota(builder, xla::S32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  xla::XlaOp negative_one_r0 = xla::ConstantR0<int>(builder, -1);
  xla::XlaOp negative_one_r2 = xla::Broadcast(negative_one_r0, {height, width});

  xla::XlaOp negative_infinity_r0 = xla::ConstantR0<float>(builder, -INFINITY);
  xla::XlaOp negative_infinity_r2 =
      xla::Broadcast(negative_infinity_r0, {height, width});

  xla::XlaOp scratch_pad_r2 = input;
  std::vector<xla::XlaOp> topk_r1s;
  std::vector<xla::XlaOp> topk_indices;
  for (int i = 0; i < k; ++i) {
    xla::XlaOp kth_order_statistic_r1 =
        xla::Reduce(scratch_pad_r2, negative_infinity_r0,
                    CreateScalarMaxComputation(xla::F32, builder), {1});
    topk_r1s.push_back(kth_order_statistic_r1);

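    // Mask out everything greater than or equal to the current
    // maximum, so all duplicates of the maximum are removed from the
    // scratch pad in one pass and cannot be selected again.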
    xla::XlaOp ge_r2 = xla::Ge(input, kth_order_statistic_r1, {0});
    scratch_pad_r2 = xla::Select(ge_r2, negative_infinity_r2, input);

    if (!masked_with_iota) {
      xla::XlaOp eq_r2 = xla::Eq(input, kth_order_statistic_r1, {0});
      xla::XlaOp indices_r2 = xla::Select(eq_r2, iota_r2, negative_one_r2);
      xla::XlaOp topk_index_r1 =
          xla::Reduce(indices_r2, negative_one_r0,
                      CreateScalarMaxComputation(xla::S32, builder), {1});
      topk_indices.push_back(topk_index_r1);
    }
  }
  xla::XlaOp topk_r1_concat = xla::ConcatInDim(builder, topk_r1s, 0);
  xla::XlaOp topk_r2 =
      xla::Transpose(xla::Reshape(topk_r1_concat, {k, height}), {1, 0});

  xla::XlaOp topk_indices_r2;
  if (masked_with_iota) {
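    // The low Log2Ceiling(width) bits of each value hold the iota
    // index planted by MakeUnique; masking away everything else
    // recovers the index directly from the value's bit pattern.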
    int32_t log2_ceiling = tensorflow::Log2Ceiling(width);
    int32_t next_power_of_two = 1U << log2_ceiling;
    int32_t count_mask = next_power_of_two - 1;
    xla::XlaOp mask_r0 = xla::ConstantR0(builder, count_mask);
    xla::XlaOp mask_r2 = xla::Broadcast(mask_r0, {height, k});
    xla::XlaOp topk_r2_s32 = xla::BitcastConvertType(topk_r2, xla::S32);
    topk_indices_r2 = xla::And(topk_r2_s32, mask_r2);
  } else {
    xla::XlaOp topk_indices_concat = xla::ConcatInDim(builder, topk_indices, 0);
    topk_indices_r2 =
        xla::Transpose(xla::Reshape(topk_indices_concat, {k, height}), {1, 0});
  }
  return std::make_pair(topk_r2, topk_indices_r2);
}

class TopKUnique : public XlaOpKernel {
 public:
  explicit TopKUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    auto topk = CreateTopKUnique(builder, input, input_shape, k_, false);
    ctx->SetOutput(0, topk.first);
    ctx->SetOutput(1, topk.second);
  }

 private:
  int k_;
};
REGISTER_XLA_OP(Name("TopKUnique"), TopKUnique);

// Makes all elements in the non-batch dimension unique and close to
// their initial value on a relative scale, but potentially far from
// their initial value on an absolute scale.
//
// This operation is meant to be combined with TopKUnique to avoid
// suppressing identical elements. For most TopK users, the indices of
// the TopK elements are important, but the relative order of the TopK
// elements and their exact values are not so important. Ideally, the
// indices of the TopK elements of the output of MakeUnique are the
// same as the indices of the TopK elements of the input.
//
// It is an open question whether it is better to accept the risk that
// two elements in the input to TopK have exactly the same value, or
// the risk that MakeUnique will alter the indices of the TopK
// elements. Model owners are encouraged to experiment!
//
// Never returns a subnormal number. Never returns zero. The sign of
// each input element is always identical to the sign of the
// corresponding output element. Behavior for infinite elements is
// undefined. Behavior for subnormal elements is undefined.
//
// Algorithm:
// 1. Replace zeros with the smallest representable normal floating
//    point number with the same sign.
// 2. Mask away enough low-order bits that every value can be made
//    distinct.
// 3. Replace the low-order bits with iota.
//
// TODO(b/74994968): Replace MakeUnique with an LLO implementation of
// TopK with reasonable semantics.
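//
// For example (an illustration, not taken from the op's tests): with
// width = 8, log2_ceiling is 3, so count_mask clears the low 3
// mantissa bits of every element and iota writes a distinct value in
// 0..7 into those bits, making all eight elements of a row bitwise
// distinct.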
xla::XlaOp CreateMakeUnique(xla::XlaBuilder* builder, const xla::XlaOp input,
                            const TensorShape& input_shape) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp zero_r0 = xla::ConstantR0(builder, 0U);
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  // count_mask is used to mask away the low-order bits to ensure
  // that every element is distinct.
  uint32 log2_ceiling = static_cast<uint32>(std::ceil(std::log2(width)));
  uint32 next_power_of_two = 1U << log2_ceiling;
  uint32 count_mask = ~(next_power_of_two - 1);
  xla::XlaOp count_mask_r0 = xla::ConstantR0(builder, count_mask);
  xla::XlaOp count_mask_r2 = xla::Broadcast(count_mask_r0, {height, width});

  // smallest_normal is the bit representation of the smallest
  // positive normal floating point number. The sign is zero, the
  // exponent is one, and the fraction is zero.
  uint32 smallest_normal = 1U << 23;
  xla::XlaOp smallest_normal_r0 = xla::ConstantR0(builder, smallest_normal);
  xla::XlaOp smallest_normal_r2 =
      xla::Broadcast(smallest_normal_r0, {height, width});

  // low_bit_mask is used to mask away the sign bit when computing the
  // absolute value.
  uint32 low_bit_mask = ~(1U << 31);
  xla::XlaOp low_bit_mask_r0 = xla::ConstantR0(builder, low_bit_mask);
  xla::XlaOp low_bit_mask_r2 = xla::Broadcast(low_bit_mask_r0, {height, width});

  xla::XlaOp iota_r1 = xla::Iota(builder, xla::U32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  // Compare the absolute value with positive zero to handle
  // negative zero.
  //
  // Pseudocode: input_no_zeros = abs(input) == 0 ? FLT_MIN : input
  xla::XlaOp input_u32_r2 = xla::BitcastConvertType(input, xla::U32);
  xla::XlaOp abs_r2 = xla::And(input_u32_r2, low_bit_mask_r2);
  xla::XlaOp if_zero_r2 = xla::Eq(abs_r2, zero_r2);
  xla::XlaOp smallest_normal_preserving_sign_r2 =
      xla::Or(input_u32_r2, smallest_normal_r2);
  xla::XlaOp input_no_zeros_r2 =
      xla::Select(if_zero_r2, smallest_normal_preserving_sign_r2, input_u32_r2);

  // Discard the low-order bits and replace them with iota.
  xla::XlaOp and_r2 = xla::And(input_no_zeros_r2, count_mask_r2);
  xla::XlaOp or_r2 = xla::Or(and_r2, iota_r2);
  return xla::BitcastConvertType(or_r2, xla::F32);
}

class MakeUnique : public XlaOpKernel {
 public:
  explicit MakeUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    ctx->SetOutput(0, CreateMakeUnique(builder, input, input_shape));
  }
};
REGISTER_XLA_OP(Name("MakeUnique"), MakeUnique);

// Returns the TopK approximate values of the array in sorted order and
// the indices of those elements. The running time is proportional to
// the product of K and the input size.
//
// The algorithm first updates the lower bits of each element with iota,
// which is used to derive the index. The iota also serves to make each
// element unique, so that in each iteration we are guaranteed to get
// one and only one unique top-1 element.
class TopKWithUnique : public XlaOpKernel {
 public:
  explicit TopKWithUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    xla::XlaOp unique = CreateMakeUnique(builder, input, input_shape);
    auto topk = CreateTopKUnique(builder, unique, input_shape, k_, true);
    ctx->SetOutput(0, topk.first);
    ctx->SetOutput(1, topk.second);
  }

 private:
  int k_;
};
REGISTER_XLA_OP(Name("TopKWithUnique"), TopKWithUnique);
}  // namespace
}  // namespace tensorflow