1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/type_util.h" 17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 19 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 20 #include "tensorflow/compiler/xla/client/lib/constants.h" 21 #include "tensorflow/compiler/xla/client/lib/sorting.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/literal.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 #include "tensorflow/core/framework/kernel_def_builder.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/platform/macros.h" 29 30 namespace tensorflow { 31 namespace { 32 33 class InTopKOp : public XlaOpKernel { 34 public: InTopKOp(OpKernelConstruction * context)35 explicit InTopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { 36 OP_REQUIRES_OK(context, context->GetAttr("T", &targets_dtype_)); 37 OP_REQUIRES_OK(context, 38 DataTypeToPrimitiveType(targets_dtype_, &targets_type_)); 39 } 40 Compile(XlaOpKernelContext * context)41 void Compile(XlaOpKernelContext* context) override { 42 int64_t k; 43 OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &k)); 44 OP_REQUIRES(context, k >= 0, 45 errors::InvalidArgument("Need k >= 0, got ", k)); 46 const TensorShape predictions_shape = context->InputShape(0); 47 OP_REQUIRES( 48 context, predictions_shape.dims() == 2, 49 errors::InvalidArgument("predictions must be == 2-D, got shape ", 50 predictions_shape.DebugString())); 51 const TensorShape targets_shape = context->InputShape(1); 52 OP_REQUIRES(context, targets_shape.dims() == 1, 53 errors::InvalidArgument("targets must be == 1-D, got shape ", 54 targets_shape.DebugString())); 55 56 int64_t batch_size = predictions_shape.dim_size(0); 57 OP_REQUIRES(context, batch_size == targets_shape.dim_size(0), 58 errors::InvalidArgument( 59 "targets must have same elements as predictions rows. Had ", 60 targets_shape.dim_size(0), ", needed ", batch_size)); 61 62 // Given `predictions` with shape batch_size*num_classes and `target` with 63 // shape num_classes, we generate `targets_values_r1` with shape num_classes 64 // which the elements are the corresponding values of `targets` in 65 // `predictions` for each example. This step can be done using xla::Gather 66 // as well. 67 xla::XlaOp predictions_r2 = context->Input(0); 68 xla::XlaOp targets_r1 = context->Input(1); 69 70 xla::XlaBuilder* xla_builder = context->builder(); 71 xla::XlaOp iota_r1 = 72 xla::Iota(xla_builder, targets_type_, predictions_shape.dim_size(1)); 73 xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {batch_size}); 74 75 xla::XlaOp eq_r2 = xla::Eq(targets_r1, iota_r2, {0}); 76 xla::XlaOp zero_r0_f32 = xla::Zero(xla_builder, xla::F32); 77 xla::XlaOp zero_r2_f32 = xla::ZerosLike(predictions_r2); 78 xla::XlaOp select_r2 = xla::Select(eq_r2, predictions_r2, zero_r2_f32); 79 xla::XlaOp targets_values_r1 = xla::Reduce( 80 select_r2, zero_r0_f32, 81 xla::CreateScalarAddComputation(xla::F32, xla_builder), {1}); 82 83 // Calculate in each row of `predictions`, how many values are larger than 84 // the value of target class. Then return the result whether the count < k, 85 // which indicates the target is in topk. 86 xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0}); 87 xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); 88 xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); 89 xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); 90 xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); 91 xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2); 92 xla::XlaOp num_gt_r1 = xla::Reduce( 93 one_hot_r2, zero_r0, 94 xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); 95 96 xla::XlaOp result = 97 xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)), 98 xla::IsFinite(targets_values_r1)); 99 100 context->SetOutput(0, result); 101 } 102 103 protected: 104 DataType targets_dtype_; 105 xla::PrimitiveType targets_type_; 106 107 TF_DISALLOW_COPY_AND_ASSIGN(InTopKOp); 108 }; 109 110 REGISTER_XLA_OP(Name("InTopKV2") 111 .CompileTimeConstantInput("k") 112 .TypeConstraint("T", {DT_INT32, DT_INT64}), 113 InTopKOp); 114 115 } // namespace 116 } // namespace tensorflow 117