1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA implementations of Categorical op. 17 18 #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/type_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 25 #include "tensorflow/compiler/xla/client/lib/constants.h" 26 #include "tensorflow/compiler/xla/client/lib/prng.h" 27 #include "tensorflow/compiler/xla/client/xla_builder.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/types.pb.h" 33 34 namespace tensorflow { 35 namespace { 36 37 class CategoricalOp : public XlaOpKernel { 38 public: CategoricalOp(OpKernelConstruction * ctx)39 explicit CategoricalOp(OpKernelConstruction* ctx) 40 : XlaOpKernel(ctx), 41 is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {} 42 Compile(XlaOpKernelContext * ctx)43 void Compile(XlaOpKernelContext* ctx) override { 44 // Get the logits 45 const xla::XlaOp& logits = ctx->Input(0); 46 TensorShape logits_shape = ctx->InputShape(0); 47 int64_t num_samples; 48 OP_REQUIRES_OK(ctx, 49 ctx->ConstantInputAsIntScalar( 50 1, &num_samples, xla::ValueInferenceMode::kUpperBound)); 51 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), 52 errors::InvalidArgument("logits should be a matrix, got shape ", 53 logits_shape.DebugString())); 54 OP_REQUIRES(ctx, num_samples >= 0, 55 errors::InvalidArgument( 56 "num_samples should be nonnegative, got ", num_samples)); 57 58 for (int i = 0; i < 2; i++) { 59 const int64_t dim = logits_shape.dim_size(i); 60 OP_REQUIRES( 61 ctx, static_cast<int>(dim) == dim, 62 errors::InvalidArgument("logits.shape = ", logits_shape.DebugString(), 63 " too large for int")); 64 } 65 66 const int64_t batch_size = logits_shape.dim_size(0); 67 const int64_t num_classes = logits_shape.dim_size(1); 68 69 xla::Shape uniform_shape; 70 int class_dimension; 71 bool num_samples_is_dynamic = false; 72 OP_REQUIRES_OK( 73 ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic)); 74 if (num_samples != 1 || num_samples_is_dynamic) { 75 std::array<int64, 3> uniform_shape_array = { 76 {batch_size, num_samples, num_classes}}; 77 xla::PrimitiveType uniform_xla_type; 78 OP_REQUIRES_OK(ctx, 79 DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); 80 uniform_shape = 81 xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); 82 class_dimension = 2; 83 } else { 84 // Have a special case for when we only need one sample, because 85 // dimensions may be padded on architectures with tiled memory layouts, so 86 // if the num_classes or batch size is large then this can lead to 87 // expensive wasted memory. 88 std::array<int64, 2> uniform_shape_array = {{batch_size, num_classes}}; 89 xla::PrimitiveType uniform_xla_type; 90 OP_REQUIRES_OK(ctx, 91 DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); 92 uniform_shape = 93 xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); 94 class_dimension = 1; 95 } 96 xla::PrimitiveType type; 97 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type)); 98 xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx); 99 100 if (num_samples_is_dynamic) { 101 // num_samples is dimension 1 in uniform_shape_array. 102 log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1); 103 } 104 105 // Use Gumbel softmax trick to generate categorical samples. 106 // See: 107 // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ 108 // TODO(b/68769470): Switch to using a cumulative sum approach. 109 auto softmax_entries = 110 xla::Sub(logits, log_uniforms, 111 /*broadcast_dimensions=*/{0, class_dimension}); 112 113 xla::PrimitiveType xla_output_type; 114 OP_REQUIRES_OK(ctx, 115 DataTypeToPrimitiveType(output_type(0), &xla_output_type)); 116 xla::XlaOp argmax; 117 if (is_gpu_) { 118 argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type, 119 /*axis=*/class_dimension); 120 } else { 121 argmax = xla::ArgMax(softmax_entries, xla_output_type, 122 /*axis=*/class_dimension, /*stable=*/true); 123 } 124 125 if (num_samples == 1 && !num_samples_is_dynamic) { 126 argmax = xla::Reshape(argmax, {batch_size, 1}); 127 } 128 129 ctx->SetOutput(0, argmax); 130 } 131 GetLogUniforms(xla::Shape uniform_shape,xla::PrimitiveType type,XlaOpKernelContext * ctx)132 virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, 133 xla::PrimitiveType type, 134 XlaOpKernelContext* ctx) { 135 xla::XlaBuilder* builder = ctx->builder(); 136 LOG_FIRST_N(WARNING, 1) << "Warning: Using tf.random.categorical with XLA" 137 " compilation will ignore seeds."; 138 // We want a number in (0, 1) rather than [0, 1) or (0, 1]: 139 // * log(-log(0)) is ∞. 140 // * log(-log(1)) is -∞. 141 auto uniforms = xla::RngUniform( 142 xla::MinPositiveNormalValue(builder, type), 143 xla::One(builder, uniform_shape.element_type()), uniform_shape); 144 return xla::Log(-xla::Log(uniforms)); 145 } 146 147 private: 148 bool is_gpu_; 149 TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); 150 }; 151 152 // TODO(b/68769717): Rename this sampler to Categorical. 153 REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), 154 CategoricalOp); 155 156 class StatelessCategoricalOp : public CategoricalOp { 157 public: StatelessCategoricalOp(OpKernelConstruction * ctx)158 explicit StatelessCategoricalOp(OpKernelConstruction* ctx) 159 : CategoricalOp(ctx), 160 device_type_string_(ctx->device_type().type_string()) { 161 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 162 } 163 GetLogUniforms(xla::Shape uniform_shape,xla::PrimitiveType type,XlaOpKernelContext * ctx)164 xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, 165 XlaOpKernelContext* ctx) override { 166 xla::XlaOp seed = ctx->Input(2); 167 168 xla::XlaBuilder* builder = ctx->builder(); 169 if (uniform_shape.element_type() == xla::BF16) { 170 uniform_shape.set_element_type(xla::F32); 171 } 172 // We want a number in (0, 1) rather than [0, 1) or (0, 1]: 173 // * log(-log(0)) is ∞. 174 // * log(-log(1)) is -∞. 175 xla::XlaOp uniforms = StatelessRngUniform( 176 device_type_string_, seed, uniform_shape, 177 xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), 178 xla::One(builder, uniform_shape.element_type())); 179 return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); 180 } 181 Compile(XlaOpKernelContext * ctx)182 void Compile(XlaOpKernelContext* ctx) override { 183 TensorShape seed_shape = ctx->InputShape(2); 184 OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, 185 errors::InvalidArgument("seed must have shape [2], not ", 186 seed_shape.DebugString())); 187 CategoricalOp::Compile(ctx); 188 } 189 190 private: 191 DataType dtype_; 192 string device_type_string_; 193 194 TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); 195 }; 196 197 REGISTER_XLA_OP(Name("StatelessMultinomial") 198 .CompileTimeConstantInput("num_samples") 199 .TypeConstraint("T", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) 200 .TypeConstraint("Tseed", DT_INT32), 201 StatelessCategoricalOp); 202 203 } // anonymous namespace 204 } // namespace tensorflow 205