1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_ 18 19 #include "tensorflow/compiler/xla/client/xla_builder.h" 20 #include "tensorflow/compiler/xla/xla_data.pb.h" 21 22 namespace xla { 23 24 // Computes approximate top-ks by aggregating top-1s in equal-sized windows. 25 // The number and the size of the windows are determined by the `recall_target`. 26 // 27 // operand: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1} 28 // init_values: N starting values for top-1 reductions 29 // top_k: Determines the k in top-k operation. 30 // reduction_dim: Determines the dimension to compute top-k. 31 // comparator: The comparator computation to use, which should have function 32 // signatore of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool. 33 // recall_target: Valid range (0, 1]. User can trade-off quality and performance 34 // with this knob. 35 // aggregate_to_topk: When true, sorts the set of approximate top-k elements and 36 // only keep the final k elements on TPU. This option is useful when user 37 // wanted to forward the approximate results to host and aggregate the results 38 // on CPU for better throughput. 39 // reduction_input_size_override: When set to a positive value, it overrides the 40 // size determined by operands[reduction_dim] for evaluating the recall. This 41 // option is useful when the given operand is only a subset of the overall 42 // computation in SPMD or distributed pipelines, where the true input size 43 // cannot be deferred by the operand shape. 44 // 45 // Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1}, 46 // which contains the approximate top-ks from the input operands. When 47 // `aggregate_to_topk` is set to true, the output size is just top_k. When 48 // `aggregate_to_topk` is set to false, the output size varied by the target 49 // recall. For target recall = 0.9, the output size is roughly 10 * top_k. For 50 // target recall = 0.99, the output size is roughly 100 * top_k. 51 // 52 // TODO(fchern): Support other hardware platforms. 53 XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands, 54 absl::Span<const XlaOp> init_values, int64_t top_k, 55 int64_t reduction_dim, const XlaComputation& comparator, 56 float recall_target = 0.9, bool aggregate_to_topk = true, 57 int64_t reduction_input_size_override = -1); 58 59 // Fallback for platforms that haven't been optimized. 60 XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span<const XlaOp> operands, 61 absl::Span<const XlaOp> init_values, int64_t top_k, 62 int64_t reduction_dim, 63 const XlaComputation& comparator, 64 float recall_target = 0.9, 65 bool aggregate_to_topk = true, 66 int64_t reduction_input_size_override = -1); 67 68 } // namespace xla 69 70 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_H_ 71