1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_defs.h"
32 namespace tensorflow {
33 namespace {
34
ComparatorBuilder(xla::XlaBuilder * builder,xla::PrimitiveType op_type,bool is_max_k)35 xla::XlaComputation ComparatorBuilder(xla::XlaBuilder* builder,
36 xla::PrimitiveType op_type,
37 bool is_max_k) {
38 auto p0 = xla::Parameter(builder, 0, xla::ShapeUtil::MakeScalarShape(op_type),
39 "v0");
40 auto p1 = xla::Parameter(builder, 1, xla::ShapeUtil::MakeScalarShape(op_type),
41 "v1");
42 xla::Parameter(builder, 2, xla::ShapeUtil::MakeScalarShape(xla::S32), "a2");
43 xla::Parameter(builder, 3, xla::ShapeUtil::MakeScalarShape(xla::S32), "a3");
44 if (is_max_k) {
45 xla::Gt(p0, p1);
46 } else {
47 xla::Lt(p0, p1);
48 }
49 return builder->BuildAndNoteError();
50 }
51
52 class ApproxTopKOpBase : public XlaOpKernel {
53 public:
ApproxTopKOpBase(OpKernelConstruction * ctx)54 explicit ApproxTopKOpBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
55 // k is static instead of dynamic.
56 // This is required for deriving the approximation algorithm.
57 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
58 OP_REQUIRES_OK(ctx, ctx->GetAttr("reduction_dimension", &reduction_dim_));
59 OP_REQUIRES_OK(ctx, ctx->GetAttr("recall_target", &recall_target_));
60 OP_REQUIRES_OK(ctx, ctx->GetAttr("is_max_k", &is_max_k_));
61 OP_REQUIRES_OK(ctx, ctx->GetAttr("reduction_input_size_override",
62 &reduction_input_size_override_));
63 OP_REQUIRES_OK(ctx, ctx->GetAttr("aggregate_to_topk", &aggregate_to_topk_));
64 }
65
Compile(XlaOpKernelContext * ctx)66 void Compile(XlaOpKernelContext* ctx) override {
67 xla::Shape op_shape = ctx->InputXlaShape(0).value();
68 xla::PrimitiveType op_type = op_shape.element_type();
69
70 int64_t reduction_dim = reduction_dim_;
71 if (reduction_dim < 0) {
72 // Reverse index.
73 reduction_dim += op_shape.dimensions_size();
74 }
75 auto cmp_builder = ctx->builder()->CreateSubBuilder(
76 absl::StrFormat("top_k_%s_comparator", is_max_k_ ? "gt" : "lt"));
77 xla::XlaComputation comparator =
78 ComparatorBuilder(cmp_builder.get(), op_type, is_max_k_);
79
80 xla::XlaOp init_val = xla::ConstantLiteral(
81 ctx->builder(), is_max_k_ ? xla::LiteralUtil::MinValue(op_type)
82 : xla::LiteralUtil::MaxValue(op_type));
83 xla::XlaOp init_arg = xla::ConstantR0(ctx->builder(), -1);
84 xla::XlaOp iota = xla::Iota(
85 ctx->builder(),
86 xla::ShapeUtil::MakeShapeWithType<int32_t>(op_shape.dimensions()),
87 reduction_dim);
88 xla::XlaOp output_tuple = ApproxTopKFn(
89 ctx->builder(), {ctx->Input(0), iota}, {init_val, init_arg}, k_,
90 reduction_dim, comparator, recall_target_, aggregate_to_topk_,
91 reduction_input_size_override_);
92 ctx->SetOutput(0, xla::GetTupleElement(output_tuple, 0));
93 ctx->SetOutput(1, xla::GetTupleElement(output_tuple, 1));
94 }
95
96 protected:
97 virtual xla::XlaOp ApproxTopKFn(
98 xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
99 absl::Span<const xla::XlaOp> init_values, int64_t top_k,
100 int64_t reduction_dim, const xla::XlaComputation& comparator,
101 float recall_target, bool aggregate_to_topk,
102 int64_t reduction_input_size_override) const = 0;
103
104 private:
105 int64_t k_;
106 int64_t reduction_dim_;
107 float recall_target_;
108 bool is_max_k_;
109 int64_t reduction_input_size_override_;
110 bool aggregate_to_topk_;
111
112 TF_DISALLOW_COPY_AND_ASSIGN(ApproxTopKOpBase);
113 };
114
115 class TpuApproxTopKOp : public ApproxTopKOpBase {
116 public:
TpuApproxTopKOp(OpKernelConstruction * ctx)117 explicit TpuApproxTopKOp(OpKernelConstruction* ctx) : ApproxTopKOpBase(ctx) {}
118
119 protected:
ApproxTopKFn(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> operands,absl::Span<const xla::XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const xla::XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override) const120 xla::XlaOp ApproxTopKFn(
121 xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
122 absl::Span<const xla::XlaOp> init_values, int64_t top_k,
123 int64_t reduction_dim, const xla::XlaComputation& comparator,
124 float recall_target, bool aggregate_to_topk,
125 int64_t reduction_input_size_override) const override {
126 return xla::ApproxTopK(builder, operands, init_values, top_k, reduction_dim,
127 comparator, recall_target, aggregate_to_topk,
128 reduction_input_size_override);
129 }
130 };
131
132 class FallbackApproxTopKOp : public ApproxTopKOpBase {
133 public:
FallbackApproxTopKOp(OpKernelConstruction * ctx)134 explicit FallbackApproxTopKOp(OpKernelConstruction* ctx)
135 : ApproxTopKOpBase(ctx) {}
136
137 protected:
ApproxTopKFn(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> operands,absl::Span<const xla::XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const xla::XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override) const138 xla::XlaOp ApproxTopKFn(
139 xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
140 absl::Span<const xla::XlaOp> init_values, int64_t top_k,
141 int64_t reduction_dim, const xla::XlaComputation& comparator,
142 float recall_target, bool aggregate_to_topk,
143 int64_t reduction_input_size_override) const override {
144 return xla::ApproxTopKFallback(
145 builder, operands, init_values, top_k, reduction_dim, comparator,
146 recall_target, aggregate_to_topk, reduction_input_size_override);
147 }
148 };
149
150 // Register for TPU
151 REGISTER_XLA_OP(Name("ApproxTopK")
152 .Device(absl::Span<const absl::string_view>{
153 DEVICE_TPU, DEVICE_TPU_XLA_JIT})
154 .TypeConstraint("T", {DT_FLOAT, DT_HALF, DT_BFLOAT16}),
155 TpuApproxTopKOp);
156
157 // Register for all registered devices except for TPU since it is already
158 // registered.
159 REGISTER_XLA_OP(
160 Name("ApproxTopK").TypeConstraint("T", {DT_FLOAT, DT_HALF, DT_BFLOAT16}),
161 FallbackApproxTopKOp);
162
163 } // namespace
164 } // namespace tensorflow
165