• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/sorting.h"
17 
18 #include "tensorflow/compiler/xla/client/lib/comparators.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/loops.h"
21 #include "tensorflow/compiler/xla/client/lib/slicing.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 
TopK(XlaOp input,int64_t k)28 XlaOp TopK(XlaOp input, int64_t k) {
29   XlaBuilder* const builder = input.builder();
30   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
31     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
32     int last_dim = input_shape.dimensions_size() - 1;
33     int64_t last_dim_size = input_shape.dimensions(last_dim);
34     // TODO(b/148796364): tune these constants for better performance.
35     const int64_t kPerPartitionSize = 8192;        // 2^13
36     const int64_t kLastDimSizeThreshold = 524288;  // 2^19
37     const int64_t kMinNumPartitions = 8;
38     const int64_t kMinimalK = 1000;
39     if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
40         (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
41       int64_t num_partitions =
42           CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
43       if (num_partitions >= kMinNumPartitions) {
44         return TopKWithPartitions(input, k, num_partitions);
45       }
46     }
47 
48     Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions());
49     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
50     for (int64_t i = 0; i < input_shape.rank(); ++i) {
51       if (input_shape.is_dynamic_dimension(i)) {
52         // Propagate dynamic dimension from inputs to iota.
53         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
54       }
55     }
56     auto input_dims = input_shape.dimensions();
57 
58     // We can pack BF16 values to be sorted along with their index values into a
59     // single 32-bit value in some cases.
60     constexpr int32_t kLow16BitsLimit = int32_t{1} << 16;
61     constexpr int32_t kLow16BitsMask = kLow16BitsLimit - 1;
62     constexpr int32_t kHigh16BitsMask = ~kLow16BitsMask;
63 
64     // Whether to use the packed sorting algorithm for BF16 data. This change is
65     // good in general, and enables a separate TPU optimization for common cases
66     // as well (top-k for small k).
67     constexpr int kMaxLastDimSizeForSmallBatches = 1500;
68     constexpr int kSmallBatchSizeThreshold = 8;
69     const bool use_packed_bf16_sort =
70         (input_shape.element_type() == BF16 &&
71          last_dim_size < kLow16BitsLimit &&
72          (last_dim_size < kMaxLastDimSizeForSmallBatches ||
73           (input_shape.rank() == 2 &&
74            input_shape.dimensions(0) >= kSmallBatchSizeThreshold)));
75 
76     std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
77     std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
78     limit_indices[last_dim] = k;
79     std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
80 
81     XlaOp values;
82     XlaOp indices;
83     if (use_packed_bf16_sort) {
84       // Converts a 32-bit value from sign-magnitude (used for floats) to one's
85       // complement (easy to compare using integer operations) or vice versa.
86       auto sign_magnitude_to_from_ones_complement = [builder](const XlaOp in) {
87         constexpr int32_t kAllNonSignBits = 0x7fffffff;
88         XlaOp in_s32 = BitcastConvertType(in, S32);
89         return Xor(
90             And(in_s32, ConstantR0<int32_t>(builder, kAllNonSignBits)),
91             ShiftRightArithmetic(in_s32, ConstantR0<int32_t>(builder, 31)));
92       };
93 
94       // Move input values to the high 16 bits of each 32-bit element, convert
95       // them to allow integer comparisons, set the low 16 bits to one (in order
96       // to reverse the sort order of the element indices), then XOR in the iota
97       // result. This leads to the ones' complement version of the BF16 input in
98       // the high 16 bits and the ones' complement of the indices in the low 16
99       // bits.
100       XlaOp input_f32_trimmed =
101           Or(sign_magnitude_to_from_ones_complement(
102                  BitcastConvertType(ConvertElementType(input, F32), S32)),
103              ConstantR0<int32_t>(builder, kLow16BitsMask));
104       XlaOp input_and_iota = Xor(input_f32_trimmed, iota_s32);
105 
106       // Sort in reverse order so the largest elements are at the beginning.
107       // Breaking ties here is why the index bits need to be inverted.
108       XlaOp sort_result_raw = Sort(
109           {input_and_iota}, CreateScalarGtComputation({S32}, builder), last_dim,
110           /*is_stable=*/false);
111 
112       // Slice off the first k values.
113       sort_result_raw =
114           Slice(sort_result_raw, start_indices, limit_indices, strides);
115       // The k in TopK is static so we shouldn't generate a dynamic dimension
116       // even if input is dynamic.
117       sort_result_raw = RemoveDynamicDimension(sort_result_raw, last_dim);
118 
119       // Get the high 16 bits of each value from the sorted result and convert
120       // them back to BF16.
121       values = ConvertElementType(
122           BitcastConvertType(
123               And(sign_magnitude_to_from_ones_complement(sort_result_raw),
124                   ConstantR0<int32_t>(builder, kHigh16BitsMask)),
125               F32),
126           BF16);
127 
128       // Get the index values from the low 16 bits of each value and invert them
129       // again.
130       indices = And(
131           Xor(sort_result_raw, ConstantR0<int32_t>(builder, kLow16BitsMask)),
132           ConstantR0<int32_t>(builder, kLow16BitsMask));
133     } else {
134       XlaOp sort_result =
135           Sort({input, iota_s32},
136                CreateScalarGtComputation({input_shape.element_type(), S32},
137                                          iota_s32.builder()),
138                last_dim, /*is_stable=*/true);
139       values = Slice(GetTupleElement(sort_result, 0), start_indices,
140                      limit_indices, strides);
141       // The k in TopK is static so we shouldn't generate a dynamic dimension
142       // even if input is dynamic.
143       values = RemoveDynamicDimension(values, last_dim);
144       indices = Slice(GetTupleElement(sort_result, 1), start_indices,
145                       limit_indices, strides);
146       indices = RemoveDynamicDimension(indices, last_dim);
147     }
148 
149     return Tuple(builder, {values, indices});
150   });
151 }
152 
TopKWithPartitions(XlaOp input,int64_t k,int64_t num_partitions)153 XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) {
154   XlaBuilder* const builder = input.builder();
155   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
156     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
157     int last_dim = input_shape.dimensions_size() - 1;
158     // Calculate per partition size.
159     auto input_dims = input_shape.dimensions();
160     int64_t last_dim_size = input_shape.dimensions(last_dim);
161     const int64_t per_partition_size =
162         CeilOfRatio(last_dim_size, num_partitions);
163     // Do normal TopK when per partition size is smaller than or equal to k.
164     if (k >= per_partition_size) {
165       return TopK(input, k);
166     }
167 
168     Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions());
169     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
170     for (int64_t i = 0; i < input_shape.rank(); ++i) {
171       if (input_shape.is_dynamic_dimension(i)) {
172         // Propagate dynamic dimension from inputs to iota.
173         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
174       }
175     }
176 
177     auto topk_body_fn =
178         [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
179             XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
180       auto values = values_and_indices[0];
181       auto indices = values_and_indices[1];
182       auto input = values_and_indices[2];
183       auto iota_s32 = values_and_indices[3];
184 
185       // Slice value and indices for this partition.
186       XlaOp start = Mul(Add(partition, ConstantR0<int32_t>(builder, 1)),
187                         ConstantR0<int32_t>(builder, per_partition_size));
188       XlaOp sliced_input =
189           DynamicSliceInMinorDims(input, {start}, {per_partition_size});
190       XlaOp sliced_indices =
191           DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
192       // Concat with previous results.
193       sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
194       sliced_indices =
195           ConcatInDim(builder, {indices, sliced_indices}, last_dim);
196       // Sort this slice
197       XlaOp sort_result =
198           Sort({sliced_input, sliced_indices},
199                CreateScalarGtComputation({input_shape.element_type(), S32},
200                                          sliced_indices.builder()),
201                last_dim, true);
202 
203       std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
204       std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
205       std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
206       // Slice topk.
207       start_indices[last_dim] = 0;
208       limit_indices[last_dim] = k;
209       values = Slice(GetTupleElement(sort_result, 0), start_indices,
210                      limit_indices, strides);
211       indices = Slice(GetTupleElement(sort_result, 1), start_indices,
212                       limit_indices, strides);
213       return std::vector<XlaOp>{values, indices, input, iota_s32};
214     };
215 
216     // Get the values and indices for the first topk so that they can
217     // be passed to the while loop.
218     std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
219     std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
220     std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
221     start_indices[last_dim] = 0;
222     limit_indices[last_dim] = per_partition_size;
223     // Slice value and indices for the first partition.
224     XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
225     XlaOp sliced_indices =
226         Slice(iota_s32, start_indices, limit_indices, strides);
227     // Sort this slice
228     XlaOp sort_result =
229         Sort({sliced_input, sliced_indices},
230              CreateScalarGtComputation({input_shape.element_type(), S32},
231                                        sliced_indices.builder()),
232              last_dim, /*is_stable=*/true);
233 
234     // Slice topk.
235     start_indices[last_dim] = 0;
236     limit_indices[last_dim] = k;
237     XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
238                          limit_indices, strides);
239     XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
240                           limit_indices, strides);
241 
242     // Pass the result of the first TopK to the while loop and do
243     // num_partition - 1 iterations.
244     TF_ASSIGN_OR_RETURN(auto values_and_indices,
245                         ForEachIndex(num_partitions - 1, S32, topk_body_fn,
246                                      {values, indices, input, iota_s32},
247                                      "topk_with_partition", builder));
248     return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
249   });
250 }
251 
252 }  // namespace xla
253