1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17
18 #include <cmath>
19 #include <vector>
20
21 #include "absl/base/casts.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/util.h"
26
27 namespace xla {
28
ConcatScalars(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> scalars)29 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
30 absl::Span<const xla::XlaOp> scalars) {
31 std::vector<xla::XlaOp> vectors;
32 absl::c_transform(scalars, std::back_inserter(vectors),
33 [](xla::XlaOp x) { return xla::Reshape(x, {1}); });
34 return ConcatInDim(builder, vectors, 0);
35 }
36
37 namespace {
38
39 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)40 XlaOp RotateLeftU32(XlaOp v, int distance) {
41 return (v << ConstantR0<uint32>(v.builder(), distance)) |
42 ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
43 }
44
45 // The internal state of the Three Fry implementation.
46 using ThreeFry2x32State = std::array<XlaOp, 2>;
47
48 // Implements the ThreeFry counter-based PRNG algorithm.
49 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
50 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)51 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
52 XlaBuilder* builder = input[0].builder();
53 key[0] = BitcastConvertType(key[0], U32);
54 key[1] = BitcastConvertType(key[1], U32);
55
56 // Rotation distances specified by the Threefry2x32 algorithm.
57 constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
58 ThreeFry2x32State x;
59
60 std::array<XlaOp, 3> ks;
61 // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
62 ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA);
63 for (int i = 0; i < 2; ++i) {
64 ks[i] = key[i];
65 x[i] = input[i];
66 ks[2] = ks[2] ^ key[i];
67 }
68
69 x[0] = x[0] + ks[0];
70 x[1] = x[1] + ks[1];
71
72 // Performs a single round of the Threefry2x32 algorithm, with a rotation
73 // amount 'rotation'.
74 auto round = [](ThreeFry2x32State v, int rotation) {
75 v[0] = v[0] + v[1];
76 v[1] = RotateLeftU32(v[1], rotation);
77 v[1] = v[0] ^ v[1];
78 return v;
79 };
80
81 // There are no known statistical flaws with 13 rounds of Threefry2x32.
82 // We are conservative and use 20 rounds.
83 x = round(x, rotations[0]);
84 x = round(x, rotations[1]);
85 x = round(x, rotations[2]);
86 x = round(x, rotations[3]);
87 x[0] = x[0] + ks[1];
88 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 1);
89
90 x = round(x, rotations[4]);
91 x = round(x, rotations[5]);
92 x = round(x, rotations[6]);
93 x = round(x, rotations[7]);
94 x[0] = x[0] + ks[2];
95 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 2);
96
97 x = round(x, rotations[0]);
98 x = round(x, rotations[1]);
99 x = round(x, rotations[2]);
100 x = round(x, rotations[3]);
101 x[0] = x[0] + ks[0];
102 x[1] = x[1] + ks[1] + ConstantR0<uint32>(builder, 3);
103
104 x = round(x, rotations[4]);
105 x = round(x, rotations[5]);
106 x = round(x, rotations[6]);
107 x = round(x, rotations[7]);
108 x[0] = x[0] + ks[1];
109 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 4);
110
111 x = round(x, rotations[0]);
112 x = round(x, rotations[1]);
113 x = round(x, rotations[2]);
114 x = round(x, rotations[3]);
115 x[0] = x[0] + ks[2];
116 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 5);
117
118 return x;
119 }
120
121 // Converts a uint64 to two uint32s.
Uint64ToUint32s(XlaOp u64)122 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
123 XlaBuilder* builder = u64.builder();
124 XlaOp const32 = ConstantR0WithType(builder, U64, 32);
125 XlaOp fst = ConvertElementType(u64, U32);
126 XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
127 return {fst, snd};
128 }
129
130 // Converts two uint32s to a uint64.
Uint32sToUint64(std::array<XlaOp,2> u32s)131 XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
132 XlaBuilder* builder = u32s[0].builder();
133 return ConvertElementType(u32s[0], U64) |
134 ShiftLeft(ConvertElementType(u32s[1], U64),
135 ConstantR0WithType(builder, U64, 32));
136 }
137
138 // Given the initial state and the request shape of random numbers to be
139 // generated, returns the input for the random number generator and a new state.
GetThreeFryInputsAndUpdatedState(XlaOp initial_state,const Shape & shape)140 std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
141 XlaOp initial_state, const Shape& shape) {
142 XlaBuilder* builder = initial_state.builder();
143 auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
144 // initial_state is an R1, so reshape it to a scalar.
145 auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions());
146 int64 trailing_dims_product = 1;
147 for (int64 i = shape.rank() - 1; i >= 0; --i) {
148 if (shape.dimensions(i) < 2) {
149 continue;
150 }
151 input_u64 =
152 input_u64 + (Iota(builder, u64_shape, i) *
153 ConstantR0<uint64>(builder, trailing_dims_product));
154 trailing_dims_product *= shape.dimensions(i);
155 }
156 XlaOp new_state =
157 initial_state + ConstantR0<uint64>(builder, ShapeUtil::ElementsIn(shape));
158 return std::make_pair(Uint64ToUint32s(input_u64), new_state);
159 }
160
161 // Result for SplitShapeIntoHalves().
162 struct SplitShapePair {
163 Shape half_shape;
164 Shape concat_shape;
165 int64 split_dim;
166 int64 new_concat_dim;
167 };
168
169 // Split the shape on a dimension > 1 into two halves.
SplitShapeIntoHalves(const Shape & shape)170 SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
171 SplitShapePair pair;
172 if (shape.rank() == 0) {
173 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
174 pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
175 pair.split_dim = 0;
176 pair.new_concat_dim = 0;
177 return pair;
178 }
179 pair.split_dim = -1;
180 for (int64 i = 0; i < shape.rank(); ++i) {
181 if (shape.dimensions(i) % 2 == 0) {
182 pair.split_dim = i;
183 break;
184 }
185 }
186 if (pair.split_dim == -1) {
187 // No even dims. Find a dimension with maximum size.
188 for (int64 i = 0; i < shape.rank(); ++i) {
189 if (pair.split_dim == -1 ||
190 shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
191 pair.split_dim = i;
192 }
193 }
194 }
195 CHECK_GE(pair.split_dim, 0);
196 std::vector<int64> half_shape_dims;
197 std::vector<int64> concat_shape_dims;
198 for (int64 i = 0; i < shape.rank(); ++i) {
199 if (i == pair.split_dim) {
200 // Create a new trivial dim for the later concat, which is more friendly
201 // to sharding propagation.
202 half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
203 half_shape_dims.push_back(1);
204 concat_shape_dims.push_back(half_shape_dims[i]);
205 concat_shape_dims.push_back(2);
206 } else {
207 half_shape_dims.push_back(shape.dimensions(i));
208 concat_shape_dims.push_back(shape.dimensions(i));
209 }
210 }
211 pair.new_concat_dim = pair.split_dim + 1;
212 pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
213 pair.concat_shape =
214 ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
215 return pair;
216 }
217
218 // Combines a pair of split shapes. It works with scalar and non-scalar shapes.
CombineShapePair(absl::Span<const XlaOp> pair,const SplitShapePair & shape_pair,const Shape & original_shape)219 XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
220 const SplitShapePair& shape_pair,
221 const Shape& original_shape) {
222 if (original_shape.rank() == 0) {
223 return Reshape(pair[0], {});
224 }
225 XlaBuilder* builder = pair[0].builder();
226 XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
227 const int64 pre_split_size = original_shape.dimensions(shape_pair.split_dim);
228 std::vector<int64> reshape_dims(original_shape.dimensions().begin(),
229 original_shape.dimensions().end());
230 reshape_dims[shape_pair.split_dim] =
231 RoundUpToNearest<int64>(pre_split_size, 2);
232 result = Reshape(result, reshape_dims);
233 if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
234 result = Slice(result, std::vector<int64>(original_shape.rank(), 0),
235 original_shape.dimensions(),
236 std::vector<int64>(original_shape.rank(), 1));
237 }
238 return result;
239 }
240
241 // Generates random 32bits with the given shape using the Three Fry
242 // implementation. Returns the random bits and the new state.
ThreeFryRngBit32(XlaOp key,XlaOp initial_state,const Shape & shape)243 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
244 auto shape_pair = SplitShapeIntoHalves(shape);
245 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
246 GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
247 ThreeFry2x32State inputs = inputs_state.first;
248 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
249 XlaOp result = CombineShapePair(outputs, shape_pair, shape);
250 return {result, inputs_state.second};
251 }
252
253 // Generates random 64bits with the given shape using the Three Fry
254 // implementation. Returns the random bits and the new state.
ThreeFryRngBit64(XlaOp key,XlaOp initial_state,const Shape & shape)255 RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
256 std::pair<ThreeFry2x32State, XlaOp> inputs_state =
257 GetThreeFryInputsAndUpdatedState(initial_state, shape);
258 ThreeFry2x32State inputs = inputs_state.first;
259 ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
260 XlaOp result = Uint32sToUint64(outputs);
261 return {result, inputs_state.second};
262 }
263
264 // The key of the Philox random number generator.
265 using Philox4x32Key = std::array<XlaOp, 2>;
266 // The internal state of the Philox random number generator.
267 using Philox4x32State = std::array<XlaOp, 4>;
268
269 // Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32(Philox4x32State state,Philox4x32Key key)270 Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
271 // Constants specified by the Philox algorithm.
272 static const uint32 kPhiloxW32A = 0x9E3779B9;
273 static const uint32 kPhiloxW32B = 0xBB67AE85;
274 static const uint32 kPhiloxM4x32A = 0xD2511F53;
275 static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
276
277 struct HighLowPair {
278 XlaOp high;
279 XlaOp low;
280 };
281
282 // Compute the high and low words from multiplying two 32-bit integers.
283 auto mul_hi_low = [](XlaOp x, uint32 k) {
284 auto product =
285 ConvertElementType(x, U64) * ConstantR0<uint64>(x.builder(), k);
286 auto low = ConvertElementType(product, U32);
287 auto high =
288 ConvertElementType(product >> ConstantR0<uint64>(x.builder(), 32), U32);
289 return HighLowPair{high, low};
290 };
291
292 // Perform a single round of the Philox algorithm.
293 auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
294 auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
295 auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
296 return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
297 product0.high ^ x[3] ^ key[1], product0.low};
298 };
299
300 // Update the key after a round of Philox algorithm.
301 auto raise_key = [](Philox4x32Key key) {
302 XlaBuilder* builder = key[0].builder();
303 return Philox4x32Key{key[0] + ConstantR0<uint32>(builder, kPhiloxW32A),
304 key[1] + ConstantR0<uint32>(builder, kPhiloxW32B)};
305 };
306
307 static const int kNumRounds = 10;
308 for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
309 state = philox_round(state, key);
310 }
311 return state;
312 }
313
314 // Scrambles the input key so that users don't need to worry about which part
315 // of the key needs to be strong.
ScramblePhiloxKey(Philox4x32Key key)316 std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
317 XlaBuilder* builder = key[0].builder();
318 XlaOp key0 = ConvertElementType(key[0], U64);
319 XlaOp key1 = ConvertElementType(key[1], U64);
320
321 Philox4x32State state = {
322 ConvertElementType(key0, U32),
323 ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
324 ConvertElementType(key1, U32),
325 ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
326 };
327 key = {ConstantR0<uint32>(builder, 0x3ec8f720),
328 ConstantR0<uint32>(builder, 0x02461e29)};
329 state = Philox4x32(state, key);
330 XlaOp zero = ConstantR0<uint32>(builder, 0);
331 return {Philox4x32State{zero, zero, state[2], state[3]},
332 Philox4x32Key{state[0], state[1]}};
333 }
334
335 // Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
336 // U64s with the low 64bits in the front. This routine supports explicit
337 // broadcasting of the U128 tensor, with `broadcast_sizes` representing the
338 // dimensions prepended to its shape.
Uint128AddUint64(const std::array<XlaOp,2> & u128,XlaOp u64,absl::Span<const int64> broadcast_sizes={})339 std::array<XlaOp, 2> Uint128AddUint64(
340 const std::array<XlaOp, 2>& u128, XlaOp u64,
341 absl::Span<const int64> broadcast_sizes = {}) {
342 auto u128_low = u128[0];
343 auto u128_high = u128[1];
344 XlaOp new_u128_low = u128_low + u64;
345 XlaOp one = ConstantR0<uint64>(u128[0].builder(), 1);
346 XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
347 Broadcast(u128_high + one, broadcast_sizes),
348 Broadcast(u128_high, broadcast_sizes));
349 return {new_u128_low, new_u128_high};
350 }
351
Uint32sToUint128(const std::array<XlaOp,4> & u32s)352 std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
353 return {Uint32sToUint64({u32s[0], u32s[1]}),
354 Uint32sToUint64({u32s[2], u32s[3]})};
355 }
356
Uint128ToUint32s(const std::array<XlaOp,2> & u128)357 std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
358 std::array<XlaOp, 2> u128_low_32s = Uint64ToUint32s(u128[0]);
359 std::array<XlaOp, 2> u128_high_32s = Uint64ToUint32s(u128[1]);
360 return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]};
361 }
362
Uint128FromOp(XlaOp op)363 std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
364 auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {});
365 auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {});
366 return {u128_low, u128_high};
367 }
368
Uint128ToOp(std::array<XlaOp,2> u128)369 XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
370 return ConcatScalars(u128[0].builder(), {u128[0], u128[1]});
371 }
372
373 // Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
374 // as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
375 // represented as 4 U32s in the order from the least significant one to the most
376 // significant one.
GetPhiloxInputsAndUpdatedState(const Philox4x32State & state,int64 n)377 std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
378 const Philox4x32State& state, int64 n) {
379 XlaBuilder* builder = state[0].builder();
380 XlaOp iota = Iota(builder, U64, n);
381 auto state_u128 = Uint32sToUint128(state);
382 auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
383 XlaOp new_state =
384 Uint128ToOp(Uint128AddUint64(state_u128, ConstantR0<uint64>(builder, n)));
385 return std::make_pair(inputs, new_state);
386 }
387
388 // Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
389 // numbers are generated in the unit of 128bits.
GeneratePhiloxBits(int64 num_elems,XlaOp initial_state,Philox4x32Key key)390 std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
391 XlaOp initial_state,
392 Philox4x32Key key) {
393 Philox4x32State state;
394 state = Uint128ToUint32s(Uint128FromOp(initial_state));
395 const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
396 Philox4x32State inputs;
397 XlaOp new_state;
398 std::tie(inputs, new_state) =
399 GetPhiloxInputsAndUpdatedState(state, num_vector4);
400 auto outputs = Philox4x32(inputs, key);
401 return std::make_pair(outputs, new_state);
402 }
403
404 // Generates an array of primitive type U32 with the given shape containing
405 // random bits generated by the Philox algorithm. Returns the array and the new
406 // state of the random number generator.
PhiloxRngBit32(XlaOp op_key,XlaOp initial_state,const Shape & shape)407 RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
408 const Shape& shape) {
409 XlaBuilder* builder = op_key.builder();
410 const int64 num_elems = ShapeUtil::ElementsIn(shape);
411
412 Philox4x32Key key = Uint64ToUint32s(op_key);
413 Philox4x32State bits;
414 XlaOp new_state;
415 std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
416 // Combining bits[i] in a round-robin fashion, to align with non-XLA
417 // implementations
418 int64 bits_len = (num_elems + 3) / 4;
419 for (auto i = 0; i < 4; ++i) {
420 bits[i] = Reshape(bits[i], {bits_len, 1});
421 }
422 XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]},
423 /*dimension=*/1);
424 numbers = Reshape(numbers, {bits_len * 4});
425 numbers = Slice(numbers, /*start_indices=*/{0},
426 /*limit_indices=*/{num_elems},
427 /*strides=*/{1});
428 return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
429 }
430
431 // Generates an array of primitive type U64 with the given shape containing
432 // random bits generated by the Philox algorithm. Returns the array and the new
433 // state of the random number generator.
PhiloxRngBit64(XlaOp op_key,XlaOp initial_state,const Shape & shape)434 RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
435 const Shape& shape) {
436 XlaBuilder* builder = op_key.builder();
437 const int64 num_elems = ShapeUtil::ElementsIn(shape);
438
439 Philox4x32Key key = Uint64ToUint32s(op_key);
440 Philox4x32State bits32;
441 XlaOp new_state;
442 std::tie(bits32, new_state) =
443 GeneratePhiloxBits(num_elems * 2, initial_state, key);
444
445 std::array<XlaOp, 2> bits64;
446 bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
447 bits64[1] = Uint32sToUint64({bits32[2], bits32[3]});
448
449 // Combining bits64[i] in a round-robin fashion, to align with non-XLA
450 // implementations
451 int64 bits64_len = (num_elems + 1) / 2;
452 for (auto i = 0; i < 2; ++i) {
453 bits64[i] = Reshape(bits64[i], {bits64_len, 1});
454 }
455 XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]},
456 /*dimension=*/1);
457 numbers = Reshape(numbers, {bits64_len * 2});
458 numbers = Slice(numbers, /*start_indices=*/{0},
459 /*limit_indices=*/{num_elems},
460 /*strides=*/{1});
461 return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
462 }
463
ConvertRandomBitsToUniformFloatingPoint(XlaOp bits,XlaOp minval,XlaOp maxval)464 XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
465 XlaOp maxval) {
466 XlaBuilder* builder = bits.builder();
467 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
468 TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
469 builder->GetShapePtr(minval));
470 TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
471 PrimitiveType value_type = minval_shape->element_type();
472 PrimitiveType bit_type = bits_shape->element_type();
473 CHECK((value_type == F32 && bit_type == U32) ||
474 (value_type == F64 && bit_type == U64));
475
476 // Form random mantissa bits for float/double, with a leading 1 bit.
477 int num_float_bits = primitive_util::BitWidth(value_type);
478 // Subtract one as SignificandWidth includes the leading 1 bit.
479 int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
480
481 // Ignore the exponent bits and convert the mantissa bits to the floating
482 // point type.
483 bits = ShiftRightLogical(
484 bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
485
486 // We have an integer-valued floating point number in the range
487 // [0, 2**{num_mantissa_bits}).
488 XlaOp values = ConvertElementType(bits, value_type);
489
490 // Divide by 2**{-num_mantissa_bits} to get a number in the range
491 // [0.0, 1.0).
492 values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
493
494 // Multiply and add to shift to the range [minval, maxval).
495 return values * (maxval - minval) + minval;
496 });
497 }
498
ConvertRandomBitsToUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)499 XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
500 PrimitiveType type,
501 PrimitiveType unsigned_type) {
502 XlaBuilder* builder = bits.builder();
503 XlaOp range = BitcastConvertType(maxval, unsigned_type) -
504 BitcastConvertType(minval, unsigned_type);
505 XlaOp dist = Rem(bits, range);
506 XlaOp dist_div_2 =
507 ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
508
509 return minval + BitcastConvertType(dist_div_2, type) +
510 BitcastConvertType(dist - dist_div_2, type);
511 }
512
513 // Implements the Box-Muller transform, which converts random floats in the
514 // range of [0, 1] from uniform distribution to normal distribution with mean 0
515 // and variance 1. For more detail on the Box-Muller transform, see
516 // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
BoxMullerTransform(XlaOp x0,XlaOp x1)517 std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
518 // Do not send a really small number to log().
519 XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
520
521 XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
522 XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
523 return {Sin(v1) * u2, Cos(v1) * u2};
524 }
525
526 } // namespace
527
PhiloxIncreaseCounter(XlaOp counter,XlaOp delta)528 XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
529 return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
530 }
531
ThreeFryBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)532 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
533 const Shape& shape) {
534 PrimitiveType type = shape.element_type();
535 switch (type) {
536 case F32:
537 case U32:
538 case S32:
539 return ThreeFryRngBit32(key, initial_state, shape);
540 case F64:
541 case U64:
542 case S64:
543 return ThreeFryRngBit64(key, initial_state, shape);
544 default:
545 return {key.builder()->ReportError(Unimplemented(
546 "Types other than F32, F64, U32, S32, U64 and S64 "
547 "are not implemented by ThreeFryBitGenerator; got %s",
548 primitive_util::LowercasePrimitiveTypeName(type))),
549 initial_state};
550 }
551 }
552
PhiloxBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)553 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
554 const Shape& shape) {
555 PrimitiveType type = shape.element_type();
556 switch (type) {
557 case F32:
558 case U32:
559 case S32:
560 return PhiloxRngBit32(key, initial_state, shape);
561 case F64:
562 case U64:
563 case S64:
564 return PhiloxRngBit64(key, initial_state, shape);
565 default:
566 return {key.builder()->ReportError(Unimplemented(
567 "Types other than F32, F64, U32, S32, U64 and S64 "
568 "are not implemented by PhiloxFryBitGenerator; got %s",
569 primitive_util::LowercasePrimitiveTypeName(type))),
570 initial_state};
571 }
572 }
573
ScramblePhiloxKey(XlaOp key)574 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
575 Philox4x32Key pkey = Uint64ToUint32s(key);
576 auto state_key = ScramblePhiloxKey(pkey);
577 return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
578 Uint32sToUint64(state_key.second));
579 }
580
UniformFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)581 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
582 BitGeneratorTy bit_generator,
583 XlaOp minval, XlaOp maxval,
584 const Shape& shape) {
585 RngOutput bits_state = bit_generator(key, initial_state, shape);
586 XlaOp bits = bits_state.value;
587 XlaOp new_state = bits_state.state;
588 return {ConvertRandomBitsToUniformFloatingPoint(bits, minval, maxval),
589 new_state};
590 }
591
UniformIntDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)592 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
593 BitGeneratorTy bit_generator, XlaOp minval,
594 XlaOp maxval, const Shape& shape) {
595 RngOutput bits_state = bit_generator(key, initial_state, shape);
596 XlaOp bits = bits_state.value;
597 XlaOp new_state = bits_state.state;
598 PrimitiveType type = shape.element_type();
599 PrimitiveType unsigned_type;
600 if (type == U32 || type == S32) {
601 unsigned_type = U32;
602 } else {
603 DCHECK(type == U64 || type == S64);
604 unsigned_type = U64;
605 }
606 return {
607 ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
608 new_state};
609 }
610
NormalFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,const Shape & shape)611 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
612 BitGeneratorTy bit_generator,
613 const Shape& shape) {
614 PrimitiveType primitive_type = shape.element_type();
615 DCHECK(primitive_type == F32 || primitive_type == F64);
616
617 XlaBuilder* builder = key.builder();
618 auto shape_pair = SplitShapeIntoHalves(shape);
619 RngOutput bits_state = UniformFloatingPointDistribution(
620 key, initial_state, bit_generator,
621 xla::ConstantR0WithType(builder, primitive_type, 0.0),
622 xla::ConstantR0WithType(builder, primitive_type, 1.0),
623 shape_pair.concat_shape);
624
625 // Separate the bits into two groups to perform the Box-Muller transform.
626 XlaOp bits_0 = Slice(bits_state.value,
627 std::vector<int64>(shape_pair.half_shape.rank(), 0),
628 shape_pair.half_shape.dimensions(),
629 std::vector<int64>(shape_pair.half_shape.rank(), 1));
630 std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
631 bits_1_starts[shape_pair.new_concat_dim] = 1;
632 XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
633 shape_pair.concat_shape.dimensions(),
634 std::vector<int64>(shape_pair.half_shape.rank(), 1));
635 std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
636
637 // Put the numbers in the two groups back to form the requested shape.
638 XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
639 return {normal, bits_state.state};
640 }
641
642 } // namespace xla
643