1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17
18 #include <cmath>
19 #include <vector>
20
21 #include "absl/base/casts.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/util.h"
26
27 namespace xla {
28
ConcatScalars(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> scalars)29 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
30 absl::Span<const xla::XlaOp> scalars) {
31 std::vector<xla::XlaOp> vectors;
32 absl::c_transform(scalars, std::back_inserter(vectors),
33 [](xla::XlaOp x) { return xla::Reshape(x, {1}); });
34 return ConcatInDim(builder, vectors, 0);
35 }
36
37 namespace {
38
39 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)40 XlaOp RotateLeftU32(XlaOp v, int distance) {
41 return (v << ConstantR0<uint32>(v.builder(), distance)) |
42 ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
43 }
44
45 // The internal state of the Three Fry implementation.
46 using ThreeFry2x32State = std::array<XlaOp, 2>;
47
48 // Implements the ThreeFry counter-based PRNG algorithm.
49 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
50 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)51 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
52 XlaBuilder* builder = input[0].builder();
53 key[0] = BitcastConvertType(key[0], U32);
54 key[1] = BitcastConvertType(key[1], U32);
55
56 // Rotation distances specified by the Threefry2x32 algorithm.
57 constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
58 ThreeFry2x32State x;
59
60 std::array<XlaOp, 3> ks;
61 // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
62 ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA);
63 for (int i = 0; i < 2; ++i) {
64 ks[i] = key[i];
65 x[i] = input[i];
66 ks[2] = ks[2] ^ key[i];
67 }
68
69 x[0] = x[0] + ks[0];
70 x[1] = x[1] + ks[1];
71
72 // Performs a single round of the Threefry2x32 algorithm, with a rotation
73 // amount 'rotation'.
74 auto round = [](ThreeFry2x32State v, int rotation) {
75 v[0] = v[0] + v[1];
76 v[1] = RotateLeftU32(v[1], rotation);
77 v[1] = v[0] ^ v[1];
78 return v;
79 };
80
81 // There are no known statistical flaws with 13 rounds of Threefry2x32.
82 // We are conservative and use 20 rounds.
83 x = round(x, rotations[0]);
84 x = round(x, rotations[1]);
85 x = round(x, rotations[2]);
86 x = round(x, rotations[3]);
87 x[0] = x[0] + ks[1];
88 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 1);
89
90 x = round(x, rotations[4]);
91 x = round(x, rotations[5]);
92 x = round(x, rotations[6]);
93 x = round(x, rotations[7]);
94 x[0] = x[0] + ks[2];
95 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 2);
96
97 x = round(x, rotations[0]);
98 x = round(x, rotations[1]);
99 x = round(x, rotations[2]);
100 x = round(x, rotations[3]);
101 x[0] = x[0] + ks[0];
102 x[1] = x[1] + ks[1] + ConstantR0<uint32>(builder, 3);
103
104 x = round(x, rotations[4]);
105 x = round(x, rotations[5]);
106 x = round(x, rotations[6]);
107 x = round(x, rotations[7]);
108 x[0] = x[0] + ks[1];
109 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 4);
110
111 x = round(x, rotations[0]);
112 x = round(x, rotations[1]);
113 x = round(x, rotations[2]);
114 x = round(x, rotations[3]);
115 x[0] = x[0] + ks[2];
116 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 5);
117
118 return x;
119 }
120
121 // Converts a uint64 to two uint32s.
Uint64ToUint32s(XlaOp u64)122 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
123 XlaBuilder* builder = u64.builder();
124 XlaOp const32 = ConstantR0WithType(builder, U64, 32);
125 XlaOp fst = ConvertElementType(u64, U32);
126 XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
127 return {fst, snd};
128 }
129
130 // Converts two uint32s to a uint64.
Uint32sToUint64(std::array<XlaOp,2> u32s)131 XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
132 XlaBuilder* builder = u32s[0].builder();
133 return ConvertElementType(u32s[0], U64) |
134 ShiftLeft(ConvertElementType(u32s[1], U64),
135 ConstantR0WithType(builder, U64, 32));
136 }
137
138 // Given the initial state and the request shape of random numbers to be
139 // generated, returns the input for the random number generator and a new state.
GetThreeFryInputsAndUpdatedState(XlaOp initial_state,const Shape & shape)140 std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
141 XlaOp initial_state, const Shape& shape) {
142 XlaBuilder* builder = initial_state.builder();
143 auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
144 // initial_state is an R1, so reshape it to a scalar.
145 auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions());
146 int64_t trailing_dims_product = 1;
147 for (int64_t i = shape.rank() - 1; i >= 0; --i) {
148 if (shape.dimensions(i) < 2) {
149 continue;
150 }
151 input_u64 =
152 input_u64 + (Iota(builder, u64_shape, i) *
153 ConstantR0<uint64>(builder, trailing_dims_product));
154 trailing_dims_product *= shape.dimensions(i);
155 }
156 XlaOp new_state =
157 initial_state + ConstantR0<uint64>(builder, ShapeUtil::ElementsIn(shape));
158 return std::make_pair(Uint64ToUint32s(input_u64), new_state);
159 }
160
161 // Result for SplitShapeIntoHalves().
162 struct SplitShapePair {
163 Shape half_shape;
164 Shape concat_shape;
165 int64 split_dim;
166 int64 new_concat_dim;
167 };
168
169 // Split the shape on a dimension > 1 into two halves.
SplitShapeIntoHalves(const Shape & shape)170 SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
171 SplitShapePair pair;
172 if (shape.rank() == 0) {
173 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
174 pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
175 pair.split_dim = 0;
176 pair.new_concat_dim = 0;
177 return pair;
178 }
179 pair.split_dim = -1;
180 for (int64_t i = 0; i < shape.rank(); ++i) {
181 if (shape.dimensions(i) % 2 == 0) {
182 pair.split_dim = i;
183 break;
184 }
185 }
186 if (pair.split_dim == -1) {
187 // No even dims. Find a dimension with maximum size.
188 for (int64_t i = 0; i < shape.rank(); ++i) {
189 if (pair.split_dim == -1 ||
190 shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
191 pair.split_dim = i;
192 }
193 }
194 }
195 CHECK_GE(pair.split_dim, 0);
196 std::vector<int64> half_shape_dims;
197 std::vector<int64> concat_shape_dims;
198 for (int64_t i = 0; i < shape.rank(); ++i) {
199 if (i == pair.split_dim) {
200 // Create a new trivial dim for the later concat, which is more friendly
201 // to sharding propagation.
202 half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
203 half_shape_dims.push_back(1);
204 concat_shape_dims.push_back(half_shape_dims[i]);
205 concat_shape_dims.push_back(2);
206 } else {
207 half_shape_dims.push_back(shape.dimensions(i));
208 concat_shape_dims.push_back(shape.dimensions(i));
209 }
210 }
211 pair.new_concat_dim = pair.split_dim + 1;
212 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
213 pair.concat_shape =
214 ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
215 return pair;
216 }
217
218 // Combines a pair of split shapes. It works with scalar and non-scalar shapes.
CombineShapePair(absl::Span<const XlaOp> pair,const SplitShapePair & shape_pair,const Shape & original_shape)219 XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
220 const SplitShapePair& shape_pair,
221 const Shape& original_shape) {
222 if (original_shape.rank() == 0) {
223 return Reshape(pair[0], {});
224 }
225 XlaBuilder* builder = pair[0].builder();
226 XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
227 const int64_t pre_split_size =
228 original_shape.dimensions(shape_pair.split_dim);
229 std::vector<int64> reshape_dims(original_shape.dimensions().begin(),
230 original_shape.dimensions().end());
231 reshape_dims[shape_pair.split_dim] =
232 RoundUpToNearest<int64>(pre_split_size, 2);
233 result = Reshape(result, reshape_dims);
234 if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
235 result = Slice(result, std::vector<int64>(original_shape.rank(), 0),
236 original_shape.dimensions(),
237 std::vector<int64>(original_shape.rank(), 1));
238 }
239 return result;
240 }
241
242 // Generates random 32bits with the given shape using the Three Fry
243 // implementation. Returns the random bits and the new state.
ThreeFryRngBit32(XlaOp key,XlaOp initial_state,const Shape & shape)244 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
245 auto shape_pair = SplitShapeIntoHalves(shape);
246 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
247 GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
248 ThreeFry2x32State inputs = inputs_state.first;
249 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
250 XlaOp result = CombineShapePair(outputs, shape_pair, shape);
251 return {result, inputs_state.second};
252 }
253
254 // Generates random 64bits with the given shape using the Three Fry
255 // implementation. Returns the random bits and the new state.
ThreeFryRngBit64(XlaOp key,XlaOp initial_state,const Shape & shape)256 RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
257 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
258 GetThreeFryInputsAndUpdatedState(initial_state, shape);
259 ThreeFry2x32State inputs = inputs_state.first;
260 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
261 XlaOp result = Uint32sToUint64(outputs);
262 return {result, inputs_state.second};
263 }
264
265 // The key of the Philox random number generator.
266 using Philox4x32Key = std::array<XlaOp, 2>;
267 // The internal state of the Philox random number generator.
268 using Philox4x32State = std::array<XlaOp, 4>;
269
270 // Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32(Philox4x32State state,Philox4x32Key key)271 Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
272 // Constants specified by the Philox algorithm.
273 static const uint32 kPhiloxW32A = 0x9E3779B9;
274 static const uint32 kPhiloxW32B = 0xBB67AE85;
275 static const uint32 kPhiloxM4x32A = 0xD2511F53;
276 static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
277
278 struct HighLowPair {
279 XlaOp high;
280 XlaOp low;
281 };
282
283 // Compute the high and low words from multiplying two 32-bit integers.
284 auto mul_hi_low = [](XlaOp x, uint32 k) {
285 auto product =
286 ConvertElementType(x, U64) * ConstantR0<uint64>(x.builder(), k);
287 auto low = ConvertElementType(product, U32);
288 auto high =
289 ConvertElementType(product >> ConstantR0<uint64>(x.builder(), 32), U32);
290 return HighLowPair{high, low};
291 };
292
293 // Perform a single round of the Philox algorithm.
294 auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
295 auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
296 auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
297 return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
298 product0.high ^ x[3] ^ key[1], product0.low};
299 };
300
301 // Update the key after a round of Philox algorithm.
302 auto raise_key = [](Philox4x32Key key) {
303 XlaBuilder* builder = key[0].builder();
304 return Philox4x32Key{key[0] + ConstantR0<uint32>(builder, kPhiloxW32A),
305 key[1] + ConstantR0<uint32>(builder, kPhiloxW32B)};
306 };
307
308 static const int kNumRounds = 10;
309 for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
310 state = philox_round(state, key);
311 }
312 return state;
313 }
314
315 // Scrambles the input key so that users don't need to worry about which part
316 // of the key needs to be strong.
ScramblePhiloxKey(Philox4x32Key key)317 std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
318 XlaBuilder* builder = key[0].builder();
319 XlaOp key0 = ConvertElementType(key[0], U64);
320 XlaOp key1 = ConvertElementType(key[1], U64);
321
322 Philox4x32State state = {
323 ConvertElementType(key0, U32),
324 ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
325 ConvertElementType(key1, U32),
326 ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
327 };
328 key = {ConstantR0<uint32>(builder, 0x3ec8f720),
329 ConstantR0<uint32>(builder, 0x02461e29)};
330 state = Philox4x32(state, key);
331 XlaOp zero = ConstantR0<uint32>(builder, 0);
332 return {Philox4x32State{zero, zero, state[2], state[3]},
333 Philox4x32Key{state[0], state[1]}};
334 }
335
336 // Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
337 // U64s with the low 64bits in the front. This routine supports explicit
338 // broadcasting of the U128 tensor, with `broadcast_sizes` representing the
339 // dimensions prepended to its shape.
Uint128AddUint64(const std::array<XlaOp,2> & u128,XlaOp u64,absl::Span<const int64> broadcast_sizes={})340 std::array<XlaOp, 2> Uint128AddUint64(
341 const std::array<XlaOp, 2>& u128, XlaOp u64,
342 absl::Span<const int64> broadcast_sizes = {}) {
343 auto u128_low = u128[0];
344 auto u128_high = u128[1];
345 XlaOp new_u128_low = u128_low + u64;
346 XlaOp one = ConstantR0<uint64>(u128[0].builder(), 1);
347 XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
348 Broadcast(u128_high + one, broadcast_sizes),
349 Broadcast(u128_high, broadcast_sizes));
350 return {new_u128_low, new_u128_high};
351 }
352
Uint32sToUint128(const std::array<XlaOp,4> & u32s)353 std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
354 return {Uint32sToUint64({u32s[0], u32s[1]}),
355 Uint32sToUint64({u32s[2], u32s[3]})};
356 }
357
Uint128ToUint32s(const std::array<XlaOp,2> & u128)358 std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
359 std::array<XlaOp, 2> u128_low_32s = Uint64ToUint32s(u128[0]);
360 std::array<XlaOp, 2> u128_high_32s = Uint64ToUint32s(u128[1]);
361 return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]};
362 }
363
Uint128FromOp(XlaOp op)364 std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
365 auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {});
366 auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {});
367 return {u128_low, u128_high};
368 }
369
Uint128ToOp(std::array<XlaOp,2> u128)370 XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
371 return ConcatScalars(u128[0].builder(), {u128[0], u128[1]});
372 }
373
374 // Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
375 // as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
376 // represented as 4 U32s in the order from the least significant one to the most
377 // significant one.
GetPhiloxInputsAndUpdatedState(const Philox4x32State & state,int64_t n)378 std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
379 const Philox4x32State& state, int64_t n) {
380 XlaBuilder* builder = state[0].builder();
381 XlaOp iota = Iota(builder, U64, n);
382 auto state_u128 = Uint32sToUint128(state);
383 auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
384 XlaOp new_state =
385 Uint128ToOp(Uint128AddUint64(state_u128, ConstantR0<uint64>(builder, n)));
386 return std::make_pair(inputs, new_state);
387 }
388
389 // Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
390 // numbers are generated in the unit of 128bits.
GeneratePhiloxBits(int64_t num_elems,XlaOp initial_state,Philox4x32Key key)391 std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64_t num_elems,
392 XlaOp initial_state,
393 Philox4x32Key key) {
394 Philox4x32State state;
395 state = Uint128ToUint32s(Uint128FromOp(initial_state));
396 const int64_t num_vector4 = CeilOfRatio<int64>(num_elems, 4);
397 Philox4x32State inputs;
398 XlaOp new_state;
399 std::tie(inputs, new_state) =
400 GetPhiloxInputsAndUpdatedState(state, num_vector4);
401 auto outputs = Philox4x32(inputs, key);
402 return std::make_pair(outputs, new_state);
403 }
404
405 // Generates an array of primitive type U32 with the given shape containing
406 // random bits generated by the Philox algorithm. Returns the array and the new
407 // state of the random number generator.
PhiloxRngBit32(XlaOp op_key,XlaOp initial_state,const Shape & shape)408 RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
409 const Shape& shape) {
410 XlaBuilder* builder = op_key.builder();
411 const int64_t num_elems = ShapeUtil::ElementsIn(shape);
412
413 Philox4x32Key key = Uint64ToUint32s(op_key);
414 Philox4x32State bits;
415 XlaOp new_state;
416 std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
417 // Combining bits[i] in a round-robin fashion, to align with non-XLA
418 // implementations
419 int64_t bits_len = (num_elems + 3) / 4;
420 for (auto i = 0; i < 4; ++i) {
421 bits[i] = Reshape(bits[i], {bits_len, 1});
422 }
423 XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]},
424 /*dimension=*/1);
425 numbers = Reshape(numbers, {bits_len * 4});
426 numbers = Slice(numbers, /*start_indices=*/{0},
427 /*limit_indices=*/{num_elems},
428 /*strides=*/{1});
429 return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
430 }
431
432 // Generates an array of primitive type U64 with the given shape containing
433 // random bits generated by the Philox algorithm. Returns the array and the new
434 // state of the random number generator.
PhiloxRngBit64(XlaOp op_key,XlaOp initial_state,const Shape & shape)435 RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
436 const Shape& shape) {
437 XlaBuilder* builder = op_key.builder();
438 const int64_t num_elems = ShapeUtil::ElementsIn(shape);
439
440 Philox4x32Key key = Uint64ToUint32s(op_key);
441 Philox4x32State bits32;
442 XlaOp new_state;
443 std::tie(bits32, new_state) =
444 GeneratePhiloxBits(num_elems * 2, initial_state, key);
445
446 std::array<XlaOp, 2> bits64;
447 bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
448 bits64[1] = Uint32sToUint64({bits32[2], bits32[3]});
449
450 // Combining bits64[i] in a round-robin fashion, to align with non-XLA
451 // implementations
452 int64_t bits64_len = (num_elems + 1) / 2;
453 for (auto i = 0; i < 2; ++i) {
454 bits64[i] = Reshape(bits64[i], {bits64_len, 1});
455 }
456 XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]},
457 /*dimension=*/1);
458 numbers = Reshape(numbers, {bits64_len * 2});
459 numbers = Slice(numbers, /*start_indices=*/{0},
460 /*limit_indices=*/{num_elems},
461 /*strides=*/{1});
462 return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
463 }
464
ConvertRandomBitsToUniformFloatingPoint(XlaOp bits,XlaOp minval,XlaOp maxval)465 XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
466 XlaOp maxval) {
467 XlaBuilder* builder = bits.builder();
468 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
469 TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
470 builder->GetShapePtr(minval));
471 TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
472 PrimitiveType value_type = minval_shape->element_type();
473 PrimitiveType bit_type = bits_shape->element_type();
474 CHECK((value_type == F32 && bit_type == U32) ||
475 (value_type == F64 && bit_type == U64));
476
477 // Form random mantissa bits for float/double, with a leading 1 bit.
478 int num_float_bits = primitive_util::BitWidth(value_type);
479 // Subtract one as SignificandWidth includes the leading 1 bit.
480 int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
481
482 // Ignore the exponent bits and convert the mantissa bits to the floating
483 // point type.
484 bits = ShiftRightLogical(
485 bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
486
487 // We have an integer-valued floating point number in the range
488 // [0, 2**{num_mantissa_bits}).
489 XlaOp values = ConvertElementType(bits, value_type);
490
491 // Divide by 2**{-num_mantissa_bits} to get a number in the range
492 // [0.0, 1.0).
493 values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
494
495 // Multiply and add to shift to the range [minval, maxval).
496 return values * (maxval - minval) + minval;
497 });
498 }
499
ConvertRandomBitsToUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)500 XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
501 PrimitiveType type,
502 PrimitiveType unsigned_type) {
503 XlaBuilder* builder = bits.builder();
504 XlaOp range = BitcastConvertType(maxval, unsigned_type) -
505 BitcastConvertType(minval, unsigned_type);
506 XlaOp dist = Rem(bits, range);
507 XlaOp dist_div_2 =
508 ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
509
510 return minval + BitcastConvertType(dist_div_2, type) +
511 BitcastConvertType(dist - dist_div_2, type);
512 }
513
514 // Implements the Box-Muller transform, which converts random floats in the
515 // range of [0, 1] from uniform distribution to normal distribution with mean 0
516 // and variance 1. For more detail on the Box-Muller transform, see
517 // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
BoxMullerTransform(XlaOp x0,XlaOp x1)518 std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
519 // Do not send a really small number to log().
520 XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
521
522 XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
523 XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
524 return {Sin(v1) * u2, Cos(v1) * u2};
525 }
526
527 } // namespace
528
PhiloxIncreaseCounter(XlaOp counter,XlaOp delta)529 XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
530 return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
531 }
532
ThreeFryBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)533 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
534 const Shape& shape) {
535 PrimitiveType type = shape.element_type();
536 switch (type) {
537 case F32:
538 case U32:
539 case S32:
540 return ThreeFryRngBit32(key, initial_state, shape);
541 case F64:
542 case U64:
543 case S64:
544 return ThreeFryRngBit64(key, initial_state, shape);
545 default:
546 return {key.builder()->ReportError(Unimplemented(
547 "Types other than F32, F64, U32, S32, U64 and S64 "
548 "are not implemented by ThreeFryBitGenerator; got %s",
549 primitive_util::LowercasePrimitiveTypeName(type))),
550 initial_state};
551 }
552 }
553
PhiloxBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)554 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
555 const Shape& shape) {
556 PrimitiveType type = shape.element_type();
557 switch (type) {
558 case F32:
559 case U32:
560 case S32:
561 return PhiloxRngBit32(key, initial_state, shape);
562 case F64:
563 case U64:
564 case S64:
565 return PhiloxRngBit64(key, initial_state, shape);
566 default:
567 return {key.builder()->ReportError(Unimplemented(
568 "Types other than F32, F64, U32, S32, U64 and S64 "
569 "are not implemented by PhiloxFryBitGenerator; got %s",
570 primitive_util::LowercasePrimitiveTypeName(type))),
571 initial_state};
572 }
573 }
574
ScramblePhiloxKey(XlaOp key)575 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
576 Philox4x32Key pkey = Uint64ToUint32s(key);
577 auto state_key = ScramblePhiloxKey(pkey);
578 return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
579 Uint32sToUint64(state_key.second));
580 }
581
UniformFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)582 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
583 BitGeneratorTy bit_generator,
584 XlaOp minval, XlaOp maxval,
585 const Shape& shape) {
586 RngOutput bits_state = bit_generator(key, initial_state, shape);
587 XlaOp bits = bits_state.value;
588 XlaOp new_state = bits_state.state;
589 return {ConvertRandomBitsToUniformFloatingPoint(bits, minval, maxval),
590 new_state};
591 }
592
UniformIntDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)593 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
594 BitGeneratorTy bit_generator, XlaOp minval,
595 XlaOp maxval, const Shape& shape) {
596 RngOutput bits_state = bit_generator(key, initial_state, shape);
597 XlaOp bits = bits_state.value;
598 XlaOp new_state = bits_state.state;
599 PrimitiveType type = shape.element_type();
600 PrimitiveType unsigned_type;
601 if (type == U32 || type == S32) {
602 unsigned_type = U32;
603 } else {
604 DCHECK(type == U64 || type == S64);
605 unsigned_type = U64;
606 }
607 return {
608 ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
609 new_state};
610 }
611
NormalFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,const Shape & shape)612 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
613 BitGeneratorTy bit_generator,
614 const Shape& shape) {
615 PrimitiveType primitive_type = shape.element_type();
616 DCHECK(primitive_type == F32 || primitive_type == F64);
617
618 XlaBuilder* builder = key.builder();
619 auto shape_pair = SplitShapeIntoHalves(shape);
620 RngOutput bits_state = UniformFloatingPointDistribution(
621 key, initial_state, bit_generator,
622 xla::ConstantR0WithType(builder, primitive_type, 0.0),
623 xla::ConstantR0WithType(builder, primitive_type, 1.0),
624 shape_pair.concat_shape);
625
626 // Separate the bits into two groups to perform the Box-Muller transform.
627 XlaOp bits_0 = Slice(bits_state.value,
628 std::vector<int64>(shape_pair.half_shape.rank(), 0),
629 shape_pair.half_shape.dimensions(),
630 std::vector<int64>(shape_pair.half_shape.rank(), 1));
631 std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
632 bits_1_starts[shape_pair.new_concat_dim] = 1;
633 XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
634 shape_pair.concat_shape.dimensions(),
635 std::vector<int64>(shape_pair.half_shape.rank(), 1));
636 std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
637
638 // Put the numbers in the two groups back to form the requested shape.
639 XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
640 return {normal, bits_state.state};
641 }
642
643 } // namespace xla
644