• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/sorting.h"
17 
18 #include "tensorflow/compiler/xla/client/lib/comparators.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/loops.h"
21 #include "tensorflow/compiler/xla/client/lib/slicing.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 
TopK(XlaOp input,int64_t k)28 XlaOp TopK(XlaOp input, int64_t k) {
29   XlaBuilder* const builder = input.builder();
30   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
31     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
32     int last_dim = input_shape.dimensions_size() - 1;
33     int64_t last_dim_size = input_shape.dimensions(last_dim);
34     // TODO(b/148796364): tune these constants for better performance.
35     const int64_t kPerPartitionSize = 8192;        // 2^13
36     const int64_t kLastDimSizeThreshold = 524288;  // 2^19
37     const int64_t kMinNumPartitions = 8;
38     const int64_t kMinimalK = 1000;
39     if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
40         (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
41       int64_t num_partitions =
42           CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
43       if (num_partitions >= kMinNumPartitions) {
44         return TopKWithPartitions(input, k, num_partitions);
45       }
46     }
47 
48     Shape iota_shape =
49         ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
50     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
51     for (int64_t i = 0; i < input_shape.rank(); ++i) {
52       if (input_shape.is_dynamic_dimension(i)) {
53         // Propagate dynamic dimension from inputs to iota.
54         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
55       }
56     }
57     auto input_dims = input_shape.dimensions();
58     XlaOp sort_result =
59         Sort({input, iota_s32},
60              CreateScalarGtComputation({input_shape.element_type(), S32},
61                                        iota_s32.builder()),
62              last_dim, /*is_stable=*/true);
63     std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
64     std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
65     limit_indices[last_dim] = k;
66     std::vector<int64> strides(input_shape.dimensions_size(), 1);
67 
68     XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
69                          limit_indices, strides);
70     // The k in TopK is static so we shouldn't generate a dynamic dimension even
71     // if input is dynamic.
72     values = RemoveDynamicDimension(values, last_dim);
73     XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
74                           limit_indices, strides);
75     indices = RemoveDynamicDimension(indices, last_dim);
76     return Tuple(builder, {values, indices});
77   });
78 }
79 
TopKWithPartitions(XlaOp input,int64_t k,int64_t num_partitions)80 XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) {
81   XlaBuilder* const builder = input.builder();
82   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
83     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
84     int last_dim = input_shape.dimensions_size() - 1;
85     // Calculate per partition size.
86     auto input_dims = input_shape.dimensions();
87     int64_t last_dim_size = input_shape.dimensions(last_dim);
88     const int64_t per_partition_size =
89         CeilOfRatio(last_dim_size, num_partitions);
90     // Do normal TopK when per partition size is smaller than or equal to k.
91     if (k >= per_partition_size) {
92       return TopK(input, k);
93     }
94 
95     Shape iota_shape =
96         ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
97     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
98     for (int64_t i = 0; i < input_shape.rank(); ++i) {
99       if (input_shape.is_dynamic_dimension(i)) {
100         // Propagate dynamic dimension from inputs to iota.
101         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
102       }
103     }
104 
105     auto topk_body_fn =
106         [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
107             XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
108       auto values = values_and_indices[0];
109       auto indices = values_and_indices[1];
110       auto input = values_and_indices[2];
111       auto iota_s32 = values_and_indices[3];
112 
113       // Slice value and indices for this partition.
114       XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
115                         ConstantR0<int32>(builder, per_partition_size));
116       XlaOp sliced_input =
117           DynamicSliceInMinorDims(input, {start}, {per_partition_size});
118       XlaOp sliced_indices =
119           DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
120       // Concat with previous results.
121       sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
122       sliced_indices =
123           ConcatInDim(builder, {indices, sliced_indices}, last_dim);
124       // Sort this slice
125       XlaOp sort_result =
126           Sort({sliced_input, sliced_indices},
127                CreateScalarGtComputation({input_shape.element_type(), S32},
128                                          sliced_indices.builder()),
129                last_dim, true);
130 
131       std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
132       std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
133       std::vector<int64> strides(input_shape.dimensions_size(), 1);
134       // Slice topk.
135       start_indices[last_dim] = 0;
136       limit_indices[last_dim] = k;
137       values = Slice(GetTupleElement(sort_result, 0), start_indices,
138                      limit_indices, strides);
139       indices = Slice(GetTupleElement(sort_result, 1), start_indices,
140                       limit_indices, strides);
141       return std::vector<XlaOp>{values, indices, input, iota_s32};
142     };
143 
144     // Get the values and indices for the first topk so that they can
145     // be passed to the while loop.
146     std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
147     std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
148     std::vector<int64> strides(input_shape.dimensions_size(), 1);
149     start_indices[last_dim] = 0;
150     limit_indices[last_dim] = per_partition_size;
151     // Slice value and indices for the first partition.
152     XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
153     XlaOp sliced_indices =
154         Slice(iota_s32, start_indices, limit_indices, strides);
155     // Sort this slice
156     XlaOp sort_result =
157         Sort({sliced_input, sliced_indices},
158              CreateScalarGtComputation({input_shape.element_type(), S32},
159                                        sliced_indices.builder()),
160              last_dim, /*is_stable=*/true);
161 
162     // Slice topk.
163     start_indices[last_dim] = 0;
164     limit_indices[last_dim] = k;
165     XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
166                          limit_indices, strides);
167     XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
168                           limit_indices, strides);
169 
170     // Pass the result of the first TopK to the while loop and do
171     // num_partition - 1 iterations.
172     TF_ASSIGN_OR_RETURN(auto values_and_indices,
173                         ForEachIndex(num_partitions - 1, S32, topk_body_fn,
174                                      {values, indices, input, iota_s32},
175                                      "topk_with_partition", builder));
176     return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
177   });
178 }
179 
180 }  // namespace xla
181