1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/sorting.h"
17
18 #include "tensorflow/compiler/xla/client/lib/comparators.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/loops.h"
21 #include "tensorflow/compiler/xla/client/lib/slicing.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25
26 namespace xla {
27
TopK(XlaOp input,int64 k)28 XlaOp TopK(XlaOp input, int64 k) {
29 XlaBuilder* const builder = input.builder();
30 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
31 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
32 int last_dim = input_shape.dimensions_size() - 1;
33 int64 last_dim_size = input_shape.dimensions(last_dim);
34 // TODO(b/148796364): tune these constants for better performance.
35 const int64 kPerPartitionSize = 8192; // 2^13
36 const int64 kLastDimSizeThreshold = 524288; // 2^19
37 const int64 kMinNumPartitions = 8;
38 const int64 kMinimalK = 1000;
39 if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
40 (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
41 int64 num_partitions =
42 CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
43 if (num_partitions >= kMinNumPartitions) {
44 return TopKWithPartitions(input, k, num_partitions);
45 }
46 }
47
48 Shape iota_shape =
49 ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
50 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
51 for (int64 i = 0; i < input_shape.rank(); ++i) {
52 if (input_shape.is_dynamic_dimension(i)) {
53 // Propagate dynamic dimension from inputs to iota.
54 iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
55 }
56 }
57 auto input_dims = input_shape.dimensions();
58 XlaOp sort_result =
59 Sort({input, iota_s32},
60 CreateScalarGtComputation({input_shape.element_type(), S32},
61 iota_s32.builder()),
62 last_dim, /*is_stable=*/true);
63 std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
64 std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
65 limit_indices[last_dim] = k;
66 std::vector<int64> strides(input_shape.dimensions_size(), 1);
67
68 XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
69 limit_indices, strides);
70 XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
71 limit_indices, strides);
72 return Tuple(builder, {values, indices});
73 });
74 }
75
TopKWithPartitions(XlaOp input,int64 k,int64 num_partitions)76 XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
77 XlaBuilder* const builder = input.builder();
78 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
79 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
80 int last_dim = input_shape.dimensions_size() - 1;
81 // Calculate per partition size.
82 auto input_dims = input_shape.dimensions();
83 int64 last_dim_size = input_shape.dimensions(last_dim);
84 const int64 per_partition_size = CeilOfRatio(last_dim_size, num_partitions);
85 // Do normal TopK when per partition size is smaller than or equal to k.
86 if (k >= per_partition_size) {
87 return TopK(input, k);
88 }
89
90 Shape iota_shape =
91 ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
92 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
93 for (int64 i = 0; i < input_shape.rank(); ++i) {
94 if (input_shape.is_dynamic_dimension(i)) {
95 // Propagate dynamic dimension from inputs to iota.
96 iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
97 }
98 }
99
100 auto topk_body_fn =
101 [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
102 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
103 auto values = values_and_indices[0];
104 auto indices = values_and_indices[1];
105 auto input = values_and_indices[2];
106 auto iota_s32 = values_and_indices[3];
107
108 // Slice value and indices for this partition.
109 XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
110 ConstantR0<int32>(builder, per_partition_size));
111 XlaOp sliced_input =
112 DynamicSliceInMinorDims(input, {start}, {per_partition_size});
113 XlaOp sliced_indices =
114 DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
115 // Concat with previous results.
116 sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
117 sliced_indices =
118 ConcatInDim(builder, {indices, sliced_indices}, last_dim);
119 // Sort this slice
120 XlaOp sort_result =
121 Sort({sliced_input, sliced_indices},
122 CreateScalarGtComputation({input_shape.element_type(), S32},
123 sliced_indices.builder()),
124 last_dim, true);
125
126 std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
127 std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
128 std::vector<int64> strides(input_shape.dimensions_size(), 1);
129 // Slice topk.
130 start_indices[last_dim] = 0;
131 limit_indices[last_dim] = k;
132 values = Slice(GetTupleElement(sort_result, 0), start_indices,
133 limit_indices, strides);
134 indices = Slice(GetTupleElement(sort_result, 1), start_indices,
135 limit_indices, strides);
136 return std::vector<XlaOp>{values, indices, input, iota_s32};
137 };
138
139 // Get the values and indices for the first topk so that they can
140 // be passed to the while loop.
141 std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
142 std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
143 std::vector<int64> strides(input_shape.dimensions_size(), 1);
144 start_indices[last_dim] = 0;
145 limit_indices[last_dim] = per_partition_size;
146 // Slice value and indices for the first partition.
147 XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
148 XlaOp sliced_indices =
149 Slice(iota_s32, start_indices, limit_indices, strides);
150 // Sort this slice
151 XlaOp sort_result =
152 Sort({sliced_input, sliced_indices},
153 CreateScalarGtComputation({input_shape.element_type(), S32},
154 sliced_indices.builder()),
155 last_dim, /*is_stable=*/true);
156
157 // Slice topk.
158 start_indices[last_dim] = 0;
159 limit_indices[last_dim] = k;
160 XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
161 limit_indices, strides);
162 XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
163 limit_indices, strides);
164
165 // Pass the result of the first TopK to the while loop and do
166 // num_partition - 1 iterations.
167 TF_ASSIGN_OR_RETURN(auto values_and_indices,
168 ForEachIndex(num_partitions - 1, S32, topk_body_fn,
169 {values, indices, input, iota_s32},
170 "topk_with_partition", builder));
171 return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
172 });
173 }
174
175 } // namespace xla
176