• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_
18 
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/xla_data.pb.h"
21 
22 namespace xla {
23 
24 // Computes approximate top-ks by aggregating top-1s in equal-sized windows.
25 // The number and the size of the windows are determined by `recall_target`.
26 //
27 // operands: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1}.
28 // init_values: N starting values for the top-1 reductions.
29 // top_k: Determines the k in the top-k operation.
30 // reduction_dim: Determines the dimension along which top-k is computed.
31 // comparator: The comparator computation to use, which should have a function
32 //   signature of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool.
33 // recall_target: Valid range (0, 1]. Users can trade off quality and
34 //   performance with this knob. Defaults to 0.9.
35 // aggregate_to_topk: When true, sorts the set of approximate top-k elements and
36 //   keeps only the final k elements on TPU. NOTE(review): forwarding the
37 //   approximate results to host and aggregating them on CPU for better
38 //   throughput presumably corresponds to passing false -- confirm.
39 // reduction_input_size_override: When set to a positive value, it overrides the
40 //   size determined by operands[reduction_dim] for evaluating the recall. This
41 //   option is useful when the given operand is only a subset of the overall
42 //   computation in SPMD or distributed pipelines, where the true input size
43 //   cannot be inferred from the operand shape. Defaults to -1.
44 //
45 // Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1},
46 // which contains the approximate top-ks from the input operands. When
47 // `aggregate_to_topk` is set to true, the output size is just top_k. When
48 // `aggregate_to_topk` is set to false, the output size varies with the target
49 // recall. For target recall = 0.9, the output size is roughly 10 * top_k. For
50 // target recall = 0.99, the output size is roughly 100 * top_k.
51 //
52 // TODO(fchern): Support other hardware platforms.
53 XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
54                  absl::Span<const XlaOp> init_values, int64_t top_k,
55                  int64_t reduction_dim, const XlaComputation& comparator,
56                  float recall_target = 0.9, bool aggregate_to_topk = true,
57                  int64_t reduction_input_size_override = -1);
58 
59 // Fallback for platforms that haven't been optimized (same args as ApproxTopK).
60 XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span<const XlaOp> operands,
61                          absl::Span<const XlaOp> init_values, int64_t top_k,
62                          int64_t reduction_dim,
63                          const XlaComputation& comparator,
64                          float recall_target = 0.9,
65                          bool aggregate_to_topk = true,
66                          int64_t reduction_input_size_override = -1);
67 
68 }  // namespace xla
69 
70 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_
71