/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/sorting.h"

#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

XlaOp TopK(XlaOp input, int64_t k) {
  XlaBuilder* const builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    int last_dim = input_shape.dimensions_size() - 1;
    int64_t last_dim_size = input_shape.dimensions(last_dim);
    // TODO(b/148796364): tune these constants for better performance.
    const int64_t kPerPartitionSize = 8192;        // 2^13
    const int64_t kLastDimSizeThreshold = 524288;  // 2^19
    const int64_t kMinNumPartitions = 8;
    const int64_t kMinimalK = 1000;
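    // Only use the partitioned variant when k is large enough for partitioning
    // to pay off, k is still well below the partition size, and the last
    // dimension is long enough to split into at least kMinNumPartitions
    // partitions.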
    if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
        (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
      int64_t num_partitions =
          CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
      if (num_partitions >= kMinNumPartitions) {
        return TopKWithPartitions(input, k, num_partitions);
      }
    }

    Shape iota_shape =
        ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
    XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (input_shape.is_dynamic_dimension(i)) {
        // Propagate dynamic dimension from inputs to iota.
        iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
      }
    }
    auto input_dims = input_shape.dimensions();
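    // Sort the values in descending order along the last dimension, carrying
    // the iota indices along so the original positions of the top values can
    // be recovered. The sort is stable, so ties keep the smaller index.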
    XlaOp sort_result =
        Sort({input, iota_s32},
             CreateScalarGtComputation({input_shape.element_type(), S32},
                                       iota_s32.builder()),
             last_dim, /*is_stable=*/true);
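    // Slice out the first k entries along the last dimension; after the
    // descending sort these are the top-k values and their indices.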
    std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
    std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
    limit_indices[last_dim] = k;
    std::vector<int64_t> strides(input_shape.dimensions_size(), 1);

    XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
                         limit_indices, strides);
    // The k in TopK is static so we shouldn't generate a dynamic dimension
    // even if the input is dynamic.
    values = RemoveDynamicDimension(values, last_dim);
    XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                          limit_indices, strides);
    indices = RemoveDynamicDimension(indices, last_dim);
    return Tuple(builder, {values, indices});
  });
}
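
// A minimal usage sketch (not part of this file): building a TopK op on a
// hypothetical [16, 100000] F32 parameter. The builder name and parameter
// shape below are assumptions for illustration only.
//
//   XlaBuilder b("topk_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {16, 100000}), "x");
//   XlaOp topk = TopK(x, /*k=*/10);
//   XlaOp values = GetTupleElement(topk, 0);   // shape [16, 10], F32
//   XlaOp indices = GetTupleElement(topk, 1);  // shape [16, 10], S32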

XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) {
  XlaBuilder* const builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    int last_dim = input_shape.dimensions_size() - 1;
    // Calculate per partition size.
    auto input_dims = input_shape.dimensions();
    int64_t last_dim_size = input_shape.dimensions(last_dim);
    const int64_t per_partition_size =
        CeilOfRatio(last_dim_size, num_partitions);
    // Do normal TopK when per partition size is smaller than or equal to k.
    if (k >= per_partition_size) {
      return TopK(input, k);
    }

    Shape iota_shape =
        ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
    XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (input_shape.is_dynamic_dimension(i)) {
        // Propagate dynamic dimension from inputs to iota.
        iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
      }
    }

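    // Loop body: each iteration takes the current best k values/indices,
    // concatenates them with the next partition's slice, re-sorts, and keeps
    // the new top k. The input and iota operands are threaded through
    // unchanged as loop-carried state.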
    auto topk_body_fn =
        [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
            XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
      auto values = values_and_indices[0];
      auto indices = values_and_indices[1];
      auto input = values_and_indices[2];
      auto iota_s32 = values_and_indices[3];

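      // Partition 0 is processed before the loop, so iteration `partition`
      // handles partition `partition + 1`, hence the start offset below.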
      // Slice value and indices for this partition.
      XlaOp start = Mul(Add(partition, ConstantR0<int32_t>(builder, 1)),
                        ConstantR0<int32_t>(builder, per_partition_size));
      XlaOp sliced_input =
          DynamicSliceInMinorDims(input, {start}, {per_partition_size});
      XlaOp sliced_indices =
          DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
      // Concat with previous results.
      sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
      sliced_indices =
          ConcatInDim(builder, {indices, sliced_indices}, last_dim);
      // Sort this slice.
      XlaOp sort_result =
          Sort({sliced_input, sliced_indices},
               CreateScalarGtComputation({input_shape.element_type(), S32},
                                         sliced_indices.builder()),
               last_dim, /*is_stable=*/true);

      std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
      std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
      std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
      // Slice topk.
      start_indices[last_dim] = 0;
      limit_indices[last_dim] = k;
      values = Slice(GetTupleElement(sort_result, 0), start_indices,
                     limit_indices, strides);
      indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                      limit_indices, strides);
      return std::vector<XlaOp>{values, indices, input, iota_s32};
    };

    // Get the values and indices for the first topk so that they can
    // be passed to the while loop.
    std::vector<int64_t> start_indices(input_shape.dimensions_size(), 0);
    std::vector<int64_t> limit_indices(input_dims.begin(), input_dims.end());
    std::vector<int64_t> strides(input_shape.dimensions_size(), 1);
    start_indices[last_dim] = 0;
    limit_indices[last_dim] = per_partition_size;
    // Slice value and indices for the first partition.
    XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
    XlaOp sliced_indices =
        Slice(iota_s32, start_indices, limit_indices, strides);
    // Sort this slice.
    XlaOp sort_result =
        Sort({sliced_input, sliced_indices},
             CreateScalarGtComputation({input_shape.element_type(), S32},
                                       sliced_indices.builder()),
             last_dim, /*is_stable=*/true);

    // Slice topk.
    start_indices[last_dim] = 0;
    limit_indices[last_dim] = k;
    XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
                         limit_indices, strides);
    XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                          limit_indices, strides);

    // Pass the result of the first TopK to the while loop and do
    // num_partitions - 1 iterations.
    TF_ASSIGN_OR_RETURN(auto values_and_indices,
                        ForEachIndex(num_partitions - 1, S32, topk_body_fn,
                                     {values, indices, input, iota_s32},
                                     "topk_with_partition", builder));
    return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
  });
}

}  // namespace xla