/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementations of Random ops
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.

#include <cmath>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class RandomUniformOp : public XlaOpKernel {
 public:
  explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const DataType dtype = output_type(0);
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
                                        XlaHelpers::One(b, dtype), xla_shape);

    ctx->SetOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
};

REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstantInput("shape"),
                RandomUniformOp);

class RandomShuffleOp : public XlaOpKernel {
 public:
  explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    TensorShape input_shape = ctx->InputShape(0);
    const int64 n = input_shape.dim_size(0);
    int64 num_elements = 1;
    for (tensorflow::TensorShapeDim dimension : input_shape) {
      num_elements *= dimension.size;
    }

    if (num_elements <= 1 || n <= 1) {
      // No shuffling is required, so copy the input directly to the output.
      ctx->SetOutput(0, input);
      return;
    }

    if (input_shape.dims() == 1) {
      // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
      // algorithm. Fisher-Yates is simple to implement and correct, but not
      // easily parallelizable. For a sufficiently parallel architecture, it is
      // faster to sort many times than to Fisher-Yates shuffle once.

      // Shuffle values by assigning each value a random key and sorting the
      // keys. Keys can collide, causing detectable patterns in the shuffled
      // output. Collisions translate into more ascending sub-sequences in the
      // shuffled output than would be expected by chance. To avoid collisions,
      // the number of possible key values must be sufficiently large.

      // How are more than 2^32 keys created? In each loop iteration, the
      // algorithm sorts by random keys. Conceptually, the earlier iterations
      // are sorting on the lower-order bits of larger keys that are never
      // actually assembled.

      // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
      // the number of possible keys and n is the number of values. If d = n^2,
      // then the limit as n goes to infinity is 1/2. If d = n^3, then the
      // limit as n goes to infinity is zero.

      // This implementation ensures that the key space is greater than or
      // equal to the cube of the number of values. The risk of collisions can
      // be further reduced by increasing kExponent at the expense of
      // performance.

      // For kExponent = 2, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation
      // is about 1/2.

      // For kExponent = 3, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
      // about 1/3255.

      // For kExponent = 4, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
      // about 1/132622.
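
      // Added sanity check (illustrative, not part of the original analysis):
      // for kExponent = 3 the worst case above is n = 1625 with
      // d = kuint32max = 2^32 - 1 possible keys. Using the birthday
      // approximation n - d + d(1 - 1/d)^n ~= n(n - 1) / (2d), the expectation
      // is 1625 * 1624 / (2 * (2^32 - 1)) ~= 3.07e-4, i.e. about 1/3255
      // collisions per shuffle, matching the figure quoted above.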
      constexpr int kExponent = 3;
      // rounds is the smallest integer such that kuint32max^rounds is at
      // least num_elements^kExponent, i.e. the composite key space built up
      // across all rounds is at least the cube of the number of values.
      const int rounds = static_cast<int>(
          std::ceil(kExponent * std::log(num_elements) / std::log(kuint32max)));

      const xla::Shape key_shape =
          xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
      xla::XlaOp zero = xla::ConstantR0(builder, 0U);

      // Unfortunately, xla::RngUniform gives values in the half-open interval
      // rather than the closed interval, so instead of 2^32 possible keys
      // there are only 2^32 - 1 (kuint32max).
      xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);

      xla::XlaOp curr = input;
      for (int i = 0; i < rounds; ++i) {
        xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
        xla::XlaOp sorted = xla::Sort(keys, {curr});
        curr = xla::GetTupleElement(sorted, 1);
      }

      ctx->SetOutput(0, curr);
      return;
    }

    // The Fisher-Yates algorithm.

    // Generate the random swaps for the indices.
    auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
    auto swaps =
        xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
                        xla::ConstantR0<int32>(builder, n), swaps_shape);

    // Generate range(n) as the initial value for the indices to be swapped.
    xla::XlaOp indices = xla::Iota(builder, xla::S32, n);

    // Swap the indices at i and swaps[i].
    auto swap_body_fn = [&](xla::XlaOp i,
                            absl::Span<const xla::XlaOp> loop_vars,
                            xla::XlaBuilder* builder)
        -> xla::StatusOr<std::vector<xla::XlaOp>> {
      auto swaps = loop_vars[0];
      auto indices = loop_vars[1];
      // TODO(b/118437727): The explicit absl::Span wrapping below is only
      // necessary because the deprecated overload creates ambiguity for the
      // single-element span case. Remove it once the deprecated overload is
      // gone.
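      // Illustrative example (added; the values are hypothetical): with n = 4,
      // indices = {0, 1, 2, 3} and swaps = {2, 0, 3, 1}, the iteration for
      // i = 0 exchanges indices[0] with indices[swaps[0]] = indices[2],
      // yielding {2, 1, 0, 3}; each subsequent iteration applies the same
      // exchange for its own i.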
      // temp = indices[i]
      auto temp =
          xla::DynamicSlice(indices, absl::Span<const xla::XlaOp>({i}), {1});
      // swap_index = swaps[i]
      auto swap_index = xla::Reshape(
          xla::DynamicSlice(swaps, absl::Span<const xla::XlaOp>({i}), {1}), {});
      // swap_value = indices[swaps[i]]
      auto swap_value = xla::DynamicSlice(
          indices, absl::Span<const xla::XlaOp>({swap_index}), {1});
      // indices[i] = indices[swaps[i]]
      indices = xla::DynamicUpdateSlice(indices, swap_value,
                                        absl::Span<const xla::XlaOp>({i}));
      // indices[swaps[i]] = temp
      indices = xla::DynamicUpdateSlice(
          indices, temp, absl::Span<const xla::XlaOp>({swap_index}));
      return std::vector<xla::XlaOp>{swaps, indices};
    };
    // for i in range(n):
    auto swap_loop_result =
        xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
                          "indices_swap_loop", builder)
            .ValueOrDie();
    auto swapped_indices = swap_loop_result[1];

    // Gather the data using the swapped indices as the shuffled order.
    auto indices_tensor_shape = TensorShape({n});
    DataType type = ctx->expected_output_dtype(0);
    xla::XlaOp gather;
    OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
                                  indices_tensor_shape,
                                  /*axis=*/0, /*indices_are_nd=*/false, type,
                                  DT_INT32, builder, &gather));
    ctx->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
};

REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);

class RandomUniformIntOp : public XlaOpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(input_type(1), shape, &xla_shape));

    const TensorShape minval_shape = ctx->InputShape(1);
    const TensorShape maxval_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval_shape.DebugString()));

    auto minval = ctx->Input(1);
    auto maxval = ctx->Input(2);
    ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
};

REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstantInput("shape"),
                RandomUniformIntOp);

class RandomStandardNormalOp : public XlaOpKernel {
 public:
  explicit RandomStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const DataType dtype = output_type(0);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();

    // Normal distribution with a mean of 0 and a standard deviation of 1:
    xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype),
                                       XlaHelpers::One(b, dtype), xla_shape);

    ctx->SetOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
};

REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstantInput("shape"),
                RandomStandardNormalOp);

class TruncatedNormalOp : public XlaOpKernel {
 public:
  explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const DataType dtype = output_type(0);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();

    xla::XlaOp one = xla::One(b, xla_shape.element_type());
    xla::XlaOp min_positive =
        xla::MinPositiveNormalValue(b, xla_shape.element_type());
    auto uniform = xla::RngUniform(min_positive, one, xla_shape);
    ctx->SetOutput(0, TruncatedNormal(uniform));
  }
};

REGISTER_XLA_OP(Name("TruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", DT_FLOAT),
                TruncatedNormalOp);

}  // namespace
}  // namespace tensorflow