1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/sorting.h"
17
18 #include "tensorflow/compiler/xla/client/lib/comparators.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/loops.h"
21 #include "tensorflow/compiler/xla/client/lib/slicing.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25
26 namespace xla {
27
// Returns a two-element tuple (values, indices) holding the k largest
// entries of `input` along its last dimension (descending order) and
// their positions within that dimension. Because k is static, the result
// shape is static even when the input has dynamic dimensions.
XlaOp TopK(XlaOp input, int64_t k) {
  XlaBuilder* const builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    int last_dim = input_shape.dimensions_size() - 1;
    int64_t last_dim_size = input_shape.dimensions(last_dim);
    // TODO(b/148796364): tune these constants for better performance.
    const int64_t kPerPartitionSize = 8192;        // 2^13
    const int64_t kLastDimSizeThreshold = 524288;  // 2^19
    const int64_t kMinNumPartitions = 8;
    const int64_t kMinimalK = 1000;
    // For a long last dimension with a moderately large k, delegate to the
    // partitioned algorithm, which sorts small merged slices instead of the
    // whole row at once. The partition count is chosen so that each merge
    // step (running top-k prepended to one partition) stays near
    // kPerPartitionSize elements.
    if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
        (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
      int64_t num_partitions =
          CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
      if (num_partitions >= kMinNumPartitions) {
        return TopKWithPartitions(input, k, num_partitions);
      }
    }

    // Iota over the last dimension, used to recover original positions
    // after sorting.
    Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions());
    XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (input_shape.is_dynamic_dimension(i)) {
        // Propagate dynamic dimension from inputs to iota.
        iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
      }
    }
    auto input_dims = input_shape.dimensions();

    // We can pack BF16 values to be sorted along with their index values into a
    // single 32-bit value in some cases.
    constexpr int32_t kLow16BitsLimit = int32_t{1} << 16;
    constexpr int32_t kLow16BitsMask = kLow16BitsLimit - 1;
    constexpr int32_t kHigh16BitsMask = ~kLow16BitsMask;

    // Whether to use the packed sorting algorithm for BF16 data. This change is
    // good in general, and enables a separate TPU optimization for common cases
    // as well (top-k for small k). Packing requires every index to fit in
    // 16 bits, hence the last_dim_size < kLow16BitsLimit condition.
    constexpr int kMaxLastDimSizeForSmallBatches = 1500;
    constexpr int kSmallBatchSizeThreshold = 8;
    const bool use_packed_bf16_sort =
        (input_shape.element_type() == BF16 &&
         last_dim_size < kLow16BitsLimit &&
         (last_dim_size < kMaxLastDimSizeForSmallBatches ||
          (input_shape.rank() == 2 &&
           input_shape.dimensions(0) >= kSmallBatchSizeThreshold)));

    // Slice bounds that keep only the first k entries of the last dimension
    // (full extent everywhere else).
    std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
    std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
    limit_indices[last_dim] = k;
    std::vector<int64_t> strides(input_shape.dimensions_size(), 1);

    XlaOp values;
    XlaOp indices;
    if (use_packed_bf16_sort) {
      // Converts a 32-bit value from sign-magnitude (used for floats) to one's
      // complement (easy to compare using integer operations) or vice versa.
      // The conversion is an involution: applying it twice restores the input.
      auto sign_magnitude_to_from_ones_complement = [builder](const XlaOp in) {
        constexpr int32_t kAllNonSignBits = 0x7fffffff;
        XlaOp in_s32 = BitcastConvertType(in, S32);
        return Xor(
            And(in_s32, ConstantR0<int32_t>(builder, kAllNonSignBits)),
            ShiftRightArithmetic(in_s32, ConstantR0<int32_t>(builder, 31)));
      };

      // Move input values to the high 16 bits of each 32-bit element, convert
      // them to allow integer comparisons, set the low 16 bits to one (in order
      // to reverse the sort order of the element indices), then XOR in the iota
      // result. This leads to the ones' complement version of the BF16 input in
      // the high 16 bits and the ones' complement of the indices in the low 16
      // bits.
      XlaOp input_f32_trimmed =
          Or(sign_magnitude_to_from_ones_complement(
                 BitcastConvertType(ConvertElementType(input, F32), S32)),
             ConstantR0<int32_t>(builder, kLow16BitsMask));
      XlaOp input_and_iota = Xor(input_f32_trimmed, iota_s32);

      // Sort in reverse order so the largest elements are at the beginning.
      // Breaking ties here is why the index bits need to be inverted.
      XlaOp sort_result_raw = Sort(
          {input_and_iota}, CreateScalarGtComputation({S32}, builder), last_dim,
          /*is_stable=*/false);

      // Slice off the first k values.
      sort_result_raw =
          Slice(sort_result_raw, start_indices, limit_indices, strides);
      // The k in TopK is static so we shouldn't generate a dynamic dimension
      // even if input is dynamic.
      sort_result_raw = RemoveDynamicDimension(sort_result_raw, last_dim);

      // Get the high 16 bits of each value from the sorted result and convert
      // them back to BF16.
      values = ConvertElementType(
          BitcastConvertType(
              And(sign_magnitude_to_from_ones_complement(sort_result_raw),
                  ConstantR0<int32_t>(builder, kHigh16BitsMask)),
              F32),
          BF16);

      // Get the index values from the low 16 bits of each value and invert them
      // again.
      indices = And(
          Xor(sort_result_raw, ConstantR0<int32_t>(builder, kLow16BitsMask)),
          ConstantR0<int32_t>(builder, kLow16BitsMask));
    } else {
      // General path: stable-sort (value, index) pairs descending by value.
      // Stability keeps the smallest original index first on value ties.
      XlaOp sort_result =
          Sort({input, iota_s32},
               CreateScalarGtComputation({input_shape.element_type(), S32},
                                         iota_s32.builder()),
               last_dim, /*is_stable=*/true);
      values = Slice(GetTupleElement(sort_result, 0), start_indices,
                     limit_indices, strides);
      // The k in TopK is static so we shouldn't generate a dynamic dimension
      // even if input is dynamic.
      values = RemoveDynamicDimension(values, last_dim);
      indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                      limit_indices, strides);
      indices = RemoveDynamicDimension(indices, last_dim);
    }

    return Tuple(builder, {values, indices});
  });
}
152
TopKWithPartitions(XlaOp input,int64_t k,int64_t num_partitions)153 XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) {
154 XlaBuilder* const builder = input.builder();
155 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
156 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
157 int last_dim = input_shape.dimensions_size() - 1;
158 // Calculate per partition size.
159 auto input_dims = input_shape.dimensions();
160 int64_t last_dim_size = input_shape.dimensions(last_dim);
161 const int64_t per_partition_size =
162 CeilOfRatio(last_dim_size, num_partitions);
163 // Do normal TopK when per partition size is smaller than or equal to k.
164 if (k >= per_partition_size) {
165 return TopK(input, k);
166 }
167
168 Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions());
169 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
170 for (int64_t i = 0; i < input_shape.rank(); ++i) {
171 if (input_shape.is_dynamic_dimension(i)) {
172 // Propagate dynamic dimension from inputs to iota.
173 iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
174 }
175 }
176
177 auto topk_body_fn =
178 [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
179 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
180 auto values = values_and_indices[0];
181 auto indices = values_and_indices[1];
182 auto input = values_and_indices[2];
183 auto iota_s32 = values_and_indices[3];
184
185 // Slice value and indices for this partition.
186 XlaOp start = Mul(Add(partition, ConstantR0<int32_t>(builder, 1)),
187 ConstantR0<int32_t>(builder, per_partition_size));
188 XlaOp sliced_input =
189 DynamicSliceInMinorDims(input, {start}, {per_partition_size});
190 XlaOp sliced_indices =
191 DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
192 // Concat with previous results.
193 sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
194 sliced_indices =
195 ConcatInDim(builder, {indices, sliced_indices}, last_dim);
196 // Sort this slice
197 XlaOp sort_result =
198 Sort({sliced_input, sliced_indices},
199 CreateScalarGtComputation({input_shape.element_type(), S32},
200 sliced_indices.builder()),
201 last_dim, true);
202
203 std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
204 std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
205 std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
206 // Slice topk.
207 start_indices[last_dim] = 0;
208 limit_indices[last_dim] = k;
209 values = Slice(GetTupleElement(sort_result, 0), start_indices,
210 limit_indices, strides);
211 indices = Slice(GetTupleElement(sort_result, 1), start_indices,
212 limit_indices, strides);
213 return std::vector<XlaOp>{values, indices, input, iota_s32};
214 };
215
216 // Get the values and indices for the first topk so that they can
217 // be passed to the while loop.
218 std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
219 std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
220 std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
221 start_indices[last_dim] = 0;
222 limit_indices[last_dim] = per_partition_size;
223 // Slice value and indices for the first partition.
224 XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
225 XlaOp sliced_indices =
226 Slice(iota_s32, start_indices, limit_indices, strides);
227 // Sort this slice
228 XlaOp sort_result =
229 Sort({sliced_input, sliced_indices},
230 CreateScalarGtComputation({input_shape.element_type(), S32},
231 sliced_indices.builder()),
232 last_dim, /*is_stable=*/true);
233
234 // Slice topk.
235 start_indices[last_dim] = 0;
236 limit_indices[last_dim] = k;
237 XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
238 limit_indices, strides);
239 XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
240 limit_indices, strides);
241
242 // Pass the result of the first TopK to the while loop and do
243 // num_partition - 1 iterations.
244 TF_ASSIGN_OR_RETURN(auto values_and_indices,
245 ForEachIndex(num_partitions - 1, S32, topk_body_fn,
246 {values, indices, input, iota_s32},
247 "topk_with_partition", builder));
248 return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
249 });
250 }
251
252 } // namespace xla
253