• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstdint>
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_defs.h"
32 namespace tensorflow {
33 namespace {
34 
ComparatorBuilder(xla::XlaBuilder * builder,xla::PrimitiveType op_type,bool is_max_k)35 xla::XlaComputation ComparatorBuilder(xla::XlaBuilder* builder,
36                                       xla::PrimitiveType op_type,
37                                       bool is_max_k) {
38   auto p0 = xla::Parameter(builder, 0, xla::ShapeUtil::MakeScalarShape(op_type),
39                            "v0");
40   auto p1 = xla::Parameter(builder, 1, xla::ShapeUtil::MakeScalarShape(op_type),
41                            "v1");
42   xla::Parameter(builder, 2, xla::ShapeUtil::MakeScalarShape(xla::S32), "a2");
43   xla::Parameter(builder, 3, xla::ShapeUtil::MakeScalarShape(xla::S32), "a3");
44   if (is_max_k) {
45     xla::Gt(p0, p1);
46   } else {
47     xla::Lt(p0, p1);
48   }
49   return builder->BuildAndNoteError();
50 }
51 
52 class ApproxTopKOpBase : public XlaOpKernel {
53  public:
ApproxTopKOpBase(OpKernelConstruction * ctx)54   explicit ApproxTopKOpBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
55     // k is static instead of dynamic.
56     // This is required for deriving the approximation algorithm.
57     OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
58     OP_REQUIRES_OK(ctx, ctx->GetAttr("reduction_dimension", &reduction_dim_));
59     OP_REQUIRES_OK(ctx, ctx->GetAttr("recall_target", &recall_target_));
60     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_max_k", &is_max_k_));
61     OP_REQUIRES_OK(ctx, ctx->GetAttr("reduction_input_size_override",
62                                      &reduction_input_size_override_));
63     OP_REQUIRES_OK(ctx, ctx->GetAttr("aggregate_to_topk", &aggregate_to_topk_));
64   }
65 
Compile(XlaOpKernelContext * ctx)66   void Compile(XlaOpKernelContext* ctx) override {
67     xla::Shape op_shape = ctx->InputXlaShape(0).value();
68     xla::PrimitiveType op_type = op_shape.element_type();
69 
70     int64_t reduction_dim = reduction_dim_;
71     if (reduction_dim < 0) {
72       // Reverse index.
73       reduction_dim += op_shape.dimensions_size();
74     }
75     auto cmp_builder = ctx->builder()->CreateSubBuilder(
76         absl::StrFormat("top_k_%s_comparator", is_max_k_ ? "gt" : "lt"));
77     xla::XlaComputation comparator =
78         ComparatorBuilder(cmp_builder.get(), op_type, is_max_k_);
79 
80     xla::XlaOp init_val = xla::ConstantLiteral(
81         ctx->builder(), is_max_k_ ? xla::LiteralUtil::MinValue(op_type)
82                                   : xla::LiteralUtil::MaxValue(op_type));
83     xla::XlaOp init_arg = xla::ConstantR0(ctx->builder(), -1);
84     xla::XlaOp iota = xla::Iota(
85         ctx->builder(),
86         xla::ShapeUtil::MakeShapeWithType<int32_t>(op_shape.dimensions()),
87         reduction_dim);
88     xla::XlaOp output_tuple = ApproxTopKFn(
89         ctx->builder(), {ctx->Input(0), iota}, {init_val, init_arg}, k_,
90         reduction_dim, comparator, recall_target_, aggregate_to_topk_,
91         reduction_input_size_override_);
92     ctx->SetOutput(0, xla::GetTupleElement(output_tuple, 0));
93     ctx->SetOutput(1, xla::GetTupleElement(output_tuple, 1));
94   }
95 
96  protected:
97   virtual xla::XlaOp ApproxTopKFn(
98       xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
99       absl::Span<const xla::XlaOp> init_values, int64_t top_k,
100       int64_t reduction_dim, const xla::XlaComputation& comparator,
101       float recall_target, bool aggregate_to_topk,
102       int64_t reduction_input_size_override) const = 0;
103 
104  private:
105   int64_t k_;
106   int64_t reduction_dim_;
107   float recall_target_;
108   bool is_max_k_;
109   int64_t reduction_input_size_override_;
110   bool aggregate_to_topk_;
111 
112   TF_DISALLOW_COPY_AND_ASSIGN(ApproxTopKOpBase);
113 };
114 
115 class TpuApproxTopKOp : public ApproxTopKOpBase {
116  public:
TpuApproxTopKOp(OpKernelConstruction * ctx)117   explicit TpuApproxTopKOp(OpKernelConstruction* ctx) : ApproxTopKOpBase(ctx) {}
118 
119  protected:
ApproxTopKFn(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> operands,absl::Span<const xla::XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const xla::XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override) const120   xla::XlaOp ApproxTopKFn(
121       xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
122       absl::Span<const xla::XlaOp> init_values, int64_t top_k,
123       int64_t reduction_dim, const xla::XlaComputation& comparator,
124       float recall_target, bool aggregate_to_topk,
125       int64_t reduction_input_size_override) const override {
126     return xla::ApproxTopK(builder, operands, init_values, top_k, reduction_dim,
127                            comparator, recall_target, aggregate_to_topk,
128                            reduction_input_size_override);
129   }
130 };
131 
132 class FallbackApproxTopKOp : public ApproxTopKOpBase {
133  public:
FallbackApproxTopKOp(OpKernelConstruction * ctx)134   explicit FallbackApproxTopKOp(OpKernelConstruction* ctx)
135       : ApproxTopKOpBase(ctx) {}
136 
137  protected:
ApproxTopKFn(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> operands,absl::Span<const xla::XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const xla::XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override) const138   xla::XlaOp ApproxTopKFn(
139       xla::XlaBuilder* builder, absl::Span<const xla::XlaOp> operands,
140       absl::Span<const xla::XlaOp> init_values, int64_t top_k,
141       int64_t reduction_dim, const xla::XlaComputation& comparator,
142       float recall_target, bool aggregate_to_topk,
143       int64_t reduction_input_size_override) const override {
144     return xla::ApproxTopKFallback(
145         builder, operands, init_values, top_k, reduction_dim, comparator,
146         recall_target, aggregate_to_topk, reduction_input_size_override);
147   }
148 };
149 
150 // Register for TPU
151 REGISTER_XLA_OP(Name("ApproxTopK")
152                     .Device(absl::Span<const absl::string_view>{
153                         DEVICE_TPU, DEVICE_TPU_XLA_JIT})
154                     .TypeConstraint("T", {DT_FLOAT, DT_HALF, DT_BFLOAT16}),
155                 TpuApproxTopKOp);
156 
157 // Register for all registered devices except for TPU since it is already
158 // registered.
159 REGISTER_XLA_OP(
160     Name("ApproxTopK").TypeConstraint("T", {DT_FLOAT, DT_HALF, DT_BFLOAT16}),
161     FallbackApproxTopKOp);
162 
163 }  // namespace
164 }  // namespace tensorflow
165